Skip to content

Commit 6d52669

Browse files
committed
Added expiration timeout and handling for models
Signed-off-by: Michael Engel <[email protected]>
1 parent 0d47a2f commit 6d52669

File tree

5 files changed

+58
-10
lines changed

5 files changed

+58
-10
lines changed

ramalama/daemon/daemon.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import signal
55
import socketserver
66
import threading
7+
from datetime import datetime, timedelta
78

89
from ramalama.daemon.handler.ramalama import RamalamaHandler
910
from ramalama.daemon.logging import configure_logger, logger
@@ -17,35 +18,66 @@ def __init__(self, server: "RamalamaServer") -> None:
1718
def handle_kill(self, signum, frame):
1819
self.server.shutdown()
1920

21+
def handle_alarm(self, signum, frame):
22+
# check for expiration of all models, stopping them if necessary and
23+
# stop shutdown server if no models are running
24+
self.server.check_model_expiration()
25+
if not self.server.model_runner.managed_models:
26+
self.server.shutdown()
27+
return
28+
29+
# register alarm again for next check
30+
signal.alarm(self.server.idle_check_interval.seconds)
31+
2032
def __enter__(self):
2133
signal.signal(signal.SIGINT, self.handle_kill)
2234
signal.signal(signal.SIGTERM, self.handle_kill)
2335

36+
# set initial idle check to 300s == 5min to prevent service from stopping
37+
# right afer being started
38+
signal.signal(signal.SIGALRM, self.handle_alarm)
39+
signal.alarm(300)
40+
2441
def __exit__(self, type, value, traceback):
2542
pass
2643

2744

2845
class RamalamaServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
2946

30-
def __init__(self, host: str, port: int, model_store_path: str, bind_and_activate=True):
47+
def __init__(
48+
self, host: str, port: int, model_store_path: str, idle_check_interval: timedelta, bind_and_activate=True
49+
):
3150
# Do not pass a RequestHandlerClass here, we will create a custom handler in finish_request
3251
super().__init__((host, port), None, bind_and_activate)
3352

3453
self.model_store_path: str = model_store_path
3554
self.model_runner: ModelRunner = ModelRunner()
55+
self.idle_check_interval = idle_check_interval
3656

3757
self.allow_reuse_address = True
3858

3959
def finish_request(self, request, client_address):
4060
RamalamaHandler(self.model_store_path, self.model_runner, request, client_address, self)
4161

62+
def check_model_expiration(self):
63+
curr_time = datetime.now()
64+
for name, m in self.model_runner.managed_models.items():
65+
if m.expiration_date > curr_time:
66+
continue
67+
68+
try:
69+
logger.error(f"Stopping expired model '{name}'...")
70+
self.model_runner.stop_model(m.id)
71+
except Exception as e:
72+
logger.error(f"Failed to stop expired model '{name}': {e}")
73+
4274
def shutdown(self):
4375
logger.info("Shutting down ramalama daemon...")
4476

4577
for name, managed_model in self.model_runner.managed_models.items():
4678
try:
4779
logger.info(f"Stopping model runner {name}...")
48-
managed_model.stop()
80+
self.model_runner.stop_model(managed_model.id)
4981
except Exception as e:
5082
logger.error(f"Error stopping model runner {name}: {e}")
5183

@@ -64,7 +96,7 @@ def parse_args():
6496
def run(host: str = "0.0.0.0", port: int = 8080, model_store_path: str = "/models"):
6597
configure_logger("DEBUG")
6698
logger.info(f"Starting Ramalama daemon on {host}:{port}...")
67-
with RamalamaServer(host, port, model_store_path) as httpd:
99+
with RamalamaServer(host, port, model_store_path, timedelta(seconds=10)) as httpd:
68100
with ShutdownHandler(httpd):
69101
server_thread = threading.Thread(target=httpd.serve_forever, daemon=True)
70102
server_thread.start()

ramalama/daemon/handler/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import http.server
22
import json
33
from abc import ABC, abstractmethod
4-
from datetime import datetime, timedelta
54

65
from ramalama.daemon.dto.model import RunningModelResponse, running_model_list_to_dict
76
from ramalama.daemon.service.model_runner import ModelRunner
@@ -40,7 +39,6 @@ def _handle_get_running_models(self, handler: http.server.SimpleHTTPRequestHandl
4039
full_model_name = (
4140
f"{m.model.model_type}://{m.model.model_organization}/{m.model.model_name}:{m.model.model_tag}"
4241
)
43-
expiration = datetime.now() + timedelta(minutes=5)
4442
models.append(
4543
RunningModelResponse(
4644
id=m.id,
@@ -49,7 +47,7 @@ def _handle_get_running_models(self, handler: http.server.SimpleHTTPRequestHandl
4947
tag=m.model.model_tag,
5048
source=m.model.type,
5149
model=full_model_name,
52-
expires_at=expiration.strftime("%Y-%m-%dT%H:%M:%SZ"),
50+
expires_at=m.expiration_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
5351
size_vram=0,
5452
digest=m.id.replace("sha-", ""),
5553
cmd=" ".join(m.run_cmd),

ramalama/daemon/handler/daemon.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import http.server
22
import json
3-
from datetime import datetime
3+
from datetime import datetime, timedelta
44

55
from ramalama.arg_types import StoreArgs
66
from ramalama.common import generate_sha256
@@ -123,7 +123,7 @@ def _handle_post_serve(self, handler: http.server.SimpleHTTPRequestHandler):
123123

124124
logger.info(f"Starting model runner for {serve_request.model_name} with command: {cmd}")
125125
id = ModelRunner.generate_model_id(model.model_name, model.model_tag, model.model_organization)
126-
model = ManagedModel(id, model, cmd, port)
126+
model = ManagedModel(id, model, cmd, port, timedelta(seconds=30))
127127
self.model_runner.add_model(model)
128128
self.model_runner.start_model(id)
129129

ramalama/daemon/handler/proxy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def _forward_request(self, handler: http.server.SimpleHTTPRequestHandler, is_ref
7474
return
7575

7676
managed_model = self.model_runner.managed_models[model_id]
77+
managed_model.update_expiration_date()
78+
7779
target_url = f"http://0.0.0.0:{managed_model.port}{path}"
7880
method = handler.command
7981
headers = handler.headers

ramalama/daemon/service/model_runner.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import subprocess
2+
from datetime import datetime, timedelta
23
from typing import Optional
34

45
from ramalama.common import generate_sha256
@@ -7,16 +8,28 @@
78

89
class ManagedModel:
910

10-
def __init__(self, id: str, model: CLASS_MODEL_TYPES, run_cmd: list[str], port: int):
11+
def __init__(
12+
self,
13+
id: str,
14+
model: CLASS_MODEL_TYPES,
15+
run_cmd: list[str],
16+
port: int,
17+
expires_after: timedelta = timedelta(minutes=5),
18+
):
1119
self.id = id
1220
self.model = model
1321
self.run_cmd: list[str] = run_cmd
1422
self.port: str = port
23+
24+
self.expires_after = expires_after
25+
self.expiration_date: Optional[datetime] = None
26+
1527
self.process: Optional[subprocess.Popen] = None
1628

1729
def start(self):
1830
if self.process is not None:
1931
raise RuntimeError(f"Model {self.id} is already running.")
32+
self.update_expiration_date()
2033
self.process = subprocess.Popen(self.run_cmd)
2134

2235
def stop(self):
@@ -25,6 +38,9 @@ def stop(self):
2538
self.process.wait()
2639
self.process = None
2740

41+
def update_expiration_date(self):
42+
self.expiration_date = datetime.now() + self.expires_after
43+
2844

2945
class ModelRunner:
3046

@@ -47,7 +63,7 @@ def next_available_port(self) -> int:
4763

4864
@staticmethod
4965
def generate_model_id(model_name: str, model_tag: str, model_organization: str) -> str:
50-
return generate_sha256(f"{model_name}-{model_tag}-{model_organization}")
66+
return generate_sha256(f"{model_name}-{model_tag}-{model_organization}", with_sha_prefix=False)
5167

5268
def add_model(self, model: ManagedModel):
5369
if model.id in self._models:

0 commit comments

Comments
 (0)