@@ -88,8 +88,36 @@ def launch_server_process_and_send_one_request(
88
88
headers = {
89
89
"Content-Type" : "application/json; charset=utf-8" ,
90
90
}
91
- response = requests .get (f"{ base_url } /v1/models" , headers = headers )
91
+ if server_args .node_rank == 0 :
92
+ response = requests .get (f"{ base_url } /v1/models" , headers = headers )
93
+ else :
94
+ # This http api is created by launch_dummy_health_check_server for none-rank0 node.
95
+ response = requests .get (f"{ base_url } /health" , headers = headers )
92
96
if response .status_code == 200 :
97
+ # Rank-0 node send a request to sync with other node and then return.
98
+ if server_args .node_rank == 0 :
99
+ response = requests .post (
100
+ f"{ base_url } /generate" ,
101
+ json = {
102
+ "input_ids" : [0 , 1 , 2 , 3 ],
103
+ "sampling_params" : {
104
+ "max_new_tokens" : 8 ,
105
+ "temperature" : 0 ,
106
+ },
107
+ },
108
+ timeout = 600 ,
109
+ )
110
+ if response .status_code != 200 :
111
+ error = response .json ()
112
+ raise RuntimeError (f"Sync request failed: { error } " )
113
+ # Other nodes should wait for the exit signal from Rank-0 node.
114
+ else :
115
+ start_time_waiting = time .time ()
116
+ while proc .is_alive ():
117
+ if time .time () - start_time_waiting < timeout :
118
+ time .sleep (10 )
119
+ else :
120
+ raise TimeoutError ("Waiting for main node timeout!" )
93
121
return proc
94
122
except requests .RequestException :
95
123
pass
@@ -122,10 +150,19 @@ def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
122
150
123
151
proc = launch_server_process_and_send_one_request (server_args , compile_args )
124
152
125
- kill_process_tree (proc .pid )
126
-
127
153
print ("\n DeepGEMM Kernels compilation finished successfully." )
128
154
155
+ # Sleep for safety
156
+ time .sleep (10 )
157
+ if proc .is_alive ():
158
+ # This is the rank0 node.
159
+ kill_process_tree (proc .pid )
160
+ else :
161
+ try :
162
+ kill_process_tree (proc .pid )
163
+ except Exception :
164
+ pass
165
+
129
166
130
167
if __name__ == "__main__" :
131
168
parser = argparse .ArgumentParser ()
0 commit comments