4
4
import signal
5
5
import socketserver
6
6
import threading
7
+ from datetime import datetime , timedelta
7
8
8
9
from ramalama .daemon .handler .ramalama import RamalamaHandler
9
10
from ramalama .daemon .logging import configure_logger , logger
@@ -17,35 +18,66 @@ def __init__(self, server: "RamalamaServer") -> None:
17
18
def handle_kill (self , signum , frame ):
18
19
self .server .shutdown ()
19
20
21
+ def handle_alarm (self , signum , frame ):
22
+ # check for expiration of all models, stopping them if necessary and
23
+ # stop shutdown server if no models are running
24
+ self .server .check_model_expiration ()
25
+ if not self .server .model_runner .managed_models :
26
+ self .server .shutdown ()
27
+ return
28
+
29
+ # register alarm again for next check
30
+ signal .alarm (self .server .idle_check_interval .seconds )
31
+
20
32
def __enter__ (self ):
21
33
signal .signal (signal .SIGINT , self .handle_kill )
22
34
signal .signal (signal .SIGTERM , self .handle_kill )
23
35
36
+ # set initial idle check to 300s == 5min to prevent service from stopping
37
+ # right afer being started
38
+ signal .signal (signal .SIGALRM , self .handle_alarm )
39
+ signal .alarm (300 )
40
+
24
41
def __exit__ (self , type , value , traceback ):
25
42
pass
26
43
27
44
28
45
class RamalamaServer (socketserver .ThreadingMixIn , socketserver .TCPServer ):
29
46
30
- def __init__ (self , host : str , port : int , model_store_path : str , bind_and_activate = True ):
47
+ def __init__ (
48
+ self , host : str , port : int , model_store_path : str , idle_check_interval : timedelta , bind_and_activate = True
49
+ ):
31
50
# Do not pass a RequestHandlerClass here, we will create a custom handler in finish_request
32
51
super ().__init__ ((host , port ), None , bind_and_activate )
33
52
34
53
self .model_store_path : str = model_store_path
35
54
self .model_runner : ModelRunner = ModelRunner ()
55
+ self .idle_check_interval = idle_check_interval
36
56
37
57
self .allow_reuse_address = True
38
58
39
59
def finish_request (self , request , client_address ):
40
60
RamalamaHandler (self .model_store_path , self .model_runner , request , client_address , self )
41
61
62
+ def check_model_expiration (self ):
63
+ curr_time = datetime .now ()
64
+ for name , m in self .model_runner .managed_models .items ():
65
+ if m .expiration_date > curr_time :
66
+ continue
67
+
68
+ try :
69
+ logger .error (f"Stopping expired model '{ name } '..." )
70
+ self .model_runner .stop_model (m .id )
71
+ except Exception as e :
72
+ logger .error (f"Failed to stop expired model '{ name } ': { e } " )
73
+
42
74
def shutdown (self ):
43
75
logger .info ("Shutting down ramalama daemon..." )
44
76
45
77
for name , managed_model in self .model_runner .managed_models .items ():
46
78
try :
47
79
logger .info (f"Stopping model runner { name } ..." )
48
- managed_model . stop ( )
80
+ self . model_runner . stop_model ( managed_model . id )
49
81
except Exception as e :
50
82
logger .error (f"Error stopping model runner { name } : { e } " )
51
83
@@ -64,7 +96,7 @@ def parse_args():
64
96
def run (host : str = "0.0.0.0" , port : int = 8080 , model_store_path : str = "/models" ):
65
97
configure_logger ("DEBUG" )
66
98
logger .info (f"Starting Ramalama daemon on { host } :{ port } ..." )
67
- with RamalamaServer (host , port , model_store_path ) as httpd :
99
+ with RamalamaServer (host , port , model_store_path , timedelta ( seconds = 10 ) ) as httpd :
68
100
with ShutdownHandler (httpd ):
69
101
server_thread = threading .Thread (target = httpd .serve_forever , daemon = True )
70
102
server_thread .start ()
0 commit comments