Commit 173ebea

Merge pull request #1189 from engelmi/set-model-store-as-default

Set model store as default

2 parents de8eeae + d8183a8, commit 173ebea
16 files changed: +312 -157 lines

bin/ramalama
Lines changed: 7 additions & 0 deletions

@@ -69,6 +69,13 @@ def main(args):
         ramalama.perror("Error: " + str(e).strip("'\""))
         sys.exit(exit_code)
 
+    try:
+        from ramalama.migrate import ModelStoreImport
+
+        ModelStoreImport(args.store).import_all()
+    except Exception as ex:
+        print(f"Failed to import models to new store: {ex}")
+
     # Process CLI
     try:
         args.func(args)
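
The migration hook added here is deliberately best-effort: it runs on every invocation, before command dispatch, and a failure only prints a message instead of aborting the CLI. A minimal standalone sketch of that behavior (the class below is an illustrative stand-in, not the real ramalama.migrate API; the store path is hypothetical):

import os

class ModelStoreImportSketch:
    """Illustrative stand-in for ramalama.migrate.ModelStoreImport."""

    def __init__(self, store: str):
        # the legacy layout kept models under <store>/models
        self.old_models = os.path.join(store, "models")

    def import_all(self) -> None:
        if not os.path.exists(self.old_models):
            return  # already migrated: later runs are a cheap no-op
        print(f"migrating models from {self.old_models}...")

try:
    ModelStoreImportSketch("/tmp/ramalama-store").import_all()
except Exception as ex:
    # mirror the hook above: report, never abort
    print(f"Failed to import models to new store: {ex}")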

ramalama/cli.py
Lines changed: 16 additions & 40 deletions

@@ -219,7 +219,6 @@ def parse_arguments(parser):
 
 def post_parse_setup(args):
     """Perform additional setup after parsing arguments."""
-    mkdirs(args.store)
     if hasattr(args, "MODEL") and args.subcommand != "rm":
         resolved_model = shortnames.resolve(args.MODEL)
         if resolved_model:
@@ -283,23 +282,6 @@ def logout_cli(args):
     return model.logout(args)
 
 
-def mkdirs(store):
-    # List of directories to create
-    directories = [
-        "models/huggingface",
-        "repos/huggingface",
-        "models/oci",
-        "repos/oci",
-        "models/ollama",
-        "repos/ollama",
-    ]
-
-    # Create each directory
-    for directory in directories:
-        full_path = os.path.join(store, directory)
-        os.makedirs(full_path, exist_ok=True)
-
-
 def human_duration(d):
     if d < 1:
         return "Less than a second"
@@ -624,14 +606,14 @@ def convert_cli(args):
         raise ValueError("convert command cannot be run with the --nocontainer option.")
 
     target = args.TARGET
-    source = _get_source(args)
+    source_model = _get_source_model(args)
 
     tgt = shortnames.resolve(target)
     if not tgt:
         tgt = target
 
     model = ModelFactory(tgt, args).create_oci()
-    model.convert(source, args)
+    model.convert(source_model, args)
 
 
 def push_parser(subparsers):
@@ -668,44 +650,38 @@ def push_parser(subparsers):
     parser.set_defaults(func=push_cli)
 
 
-def _get_source(args):
-    if os.path.exists(args.SOURCE):
-        return args.SOURCE
-
+def _get_source_model(args):
     src = shortnames.resolve(args.SOURCE)
     if not src:
         src = args.SOURCE
     smodel = New(src, args)
     if smodel.type == "OCI":
         raise ValueError("converting from an OCI based image %s is not supported" % src)
     if not smodel.exists(args):
-        return smodel.pull(args)
-    return smodel.model_path(args)
+        smodel.pull(args)
+    return smodel
 
 
 def push_cli(args):
-    if args.TARGET:
-        target = args.TARGET
-        source = _get_source(args)
-    else:
-        target = args.SOURCE
-        source = args.SOURCE
 
-    tgt = shortnames.resolve(target)
-    if not tgt:
-        tgt = target
+    source_model = _get_source_model(args)
+    target = args.SOURCE
+    if args.TARGET:
+        target = shortnames.resolve(args.TARGET)
+        if not target:
+            target = args.TARGET
+    target_model = New(target, args)
 
     try:
-        model = New(tgt, args)
-        model.push(source, args)
+        target_model.push(source_model, args)
     except NotImplementedError as e:
         for mtype in MODEL_TYPES:
-            if tgt.startswith(mtype + "://"):
+            if target.startswith(mtype + "://"):
                 raise e
         try:
             # attempt to push as a container image
-            m = ModelFactory(tgt, args).create_oci()
-            m.push(source, args)
+            m = ModelFactory(target, args).create_oci()
+            m.push(source_model, args)
         except Exception as e1:
             if args.debug:
                 print(e1)
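
The net effect in push_cli is that source and target are now handled as model objects rather than raw paths: _get_source_model returns the model itself (pulling it if needed) and push receives that object. The shortname resolution idiom used for the target can be shown in isolation (resolve here is a stand-in for shortnames.resolve):

def resolve_or_fallback(raw, resolve):
    # prefer the shortname expansion; otherwise pass the argument through
    resolved = resolve(raw)
    return resolved if resolved else raw

assert resolve_or_fallback("tiny", lambda s: "ollama://tinyllama") == "ollama://tinyllama"
assert resolve_or_fallback("oci://quay.io/user/model", lambda s: None) == "oci://quay.io/user/model"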

ramalama/config.py
Lines changed: 1 addition & 1 deletion

@@ -101,7 +101,7 @@ def load_config_defaults(config: Dict[str, Any]):
     config.setdefault('store', get_store())
     config.setdefault('temp', "0.8")
     config.setdefault('transport', "ollama")
-    config.setdefault('use_model_store', False)
+    config.setdefault('use_model_store', True)
 
 
 class Config(ChainMap):
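
Flipping this default means the new model store is used unless the user opts out. Because load_config_defaults relies on dict.setdefault, an explicit user setting still wins; a short sketch of that layering (key name from the diff):

config = {}                                  # nothing set by the user
config.setdefault('use_model_store', True)   # new default applies
assert config['use_model_store'] is True

config = {'use_model_store': False}          # explicit opt-out
config.setdefault('use_model_store', True)   # setdefault never overrides
assert config['use_model_store'] is False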

ramalama/huggingface.py
Lines changed: 18 additions & 11 deletions

@@ -313,7 +313,7 @@ def url_pull(self, args, model_path, directory_path):
         os.symlink(relative_target_path, model_path)
         return model_path
 
-    def push(self, source, args):
+    def push(self, _, args):
         if not self.hf_cli_available:
             raise NotImplementedError(missing_huggingface)
         proc = run_cmd(
@@ -356,15 +356,18 @@ def _collect_cli_files(self, tempdir: str) -> tuple[str, list[HuggingfaceCLIFile
             if entry.lower() == "readme.md":
                 snapshot_hash = sha256
                 continue
-            files.append(
-                HuggingfaceCLIFile(
-                    url=entry_path,
-                    header={},
-                    hash=sha256,
-                    type=SnapshotFileType.Other,
-                    name=entry,
-                )
+
+            hf_file = HuggingfaceCLIFile(
+                url=entry_path,
+                header={},
+                hash=sha256,
+                type=SnapshotFileType.Other,
+                name=entry,
             )
+            # try to identify the model file in the pulled repo
+            if entry.endswith(".safetensors") or entry.endswith(".gguf"):
+                hf_file.type = SnapshotFileType.Model
+            files.append(hf_file)
 
         return snapshot_hash, files
 
@@ -376,7 +379,7 @@ def _pull_with_model_store(self, debug: bool = False):
 
         try:
             # Fetch the SHA-256 checksum of model from the API and use as snapshot hash
-            snapshot_hash = f"sha256:{fetch_checksum_from_api(self.organization, self.name)}"
+            snapshot_hash = f"sha256:{fetch_checksum_from_api(organization, name)}"
 
             hf_repo = HuggingfaceRepository(name, organization)
             files = hf_repo.get_file_list(cached_files, snapshot_hash)
@@ -387,7 +390,11 @@ def _pull_with_model_store(self, debug: bool = False):
             raise KeyError(f"Failed to pull model: {str(e)}")
 
         # Cleanup previously created snapshot
-        self.store.remove_snapshot(tag)
+        try:
+            self.store.remove_snapshot(tag)
+        except Exception:
+            # ignore any error when removing snapshot
+            pass
 
         # Create temporary directory for downloading via huggingface-cli
         with tempfile.TemporaryDirectory() as tempdir:
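
Two behavior changes here: files downloaded via huggingface-cli are now classified by extension so the snapshot knows which blob is the actual model, and removing a stale snapshot no longer fails the pull. The classification rule in isolation (a sketch; only the extension test is taken from the diff):

def classify(entry: str) -> str:
    # .safetensors and .gguf are treated as model files; everything else stays auxiliary
    if entry.endswith(".safetensors") or entry.endswith(".gguf"):
        return "Model"
    return "Other"

assert classify("model-q4_K_M.gguf") == "Model"
assert classify("tokenizer.json") == "Other"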

ramalama/migrate.py
Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
+#
+# Utility for migrating from the old to the new model store
+#
+import os
+import shutil
+
+from ramalama.common import generate_sha256
+from ramalama.model import MODEL_TYPES
+from ramalama.model_factory import ModelFactory
+from ramalama.model_store import GlobalModelStore, SnapshotFile, SnapshotFileType
+
+
+class dotdict(dict):
+    """dot.notation access to dictionary attributes"""
+
+    __getattr__ = dict.get
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+
+class ModelStoreImport:
+
+    def __init__(self, store_path: str):
+        self.store_path = store_path
+        self._old_model_path = os.path.join(store_path, "models")
+        self._old_repo_path = os.path.join(store_path, "repos")
+        self._global_store = GlobalModelStore(self.store_path)
+
+    class LocalModelFile(SnapshotFile):
+
+        def __init__(
+            self, url, header, hash, name, should_show_progress=False, should_verify_checksum=False, required=True
+        ):
+            super().__init__(
+                url, header, hash, name, SnapshotFileType.Model, should_show_progress, should_verify_checksum, required
+            )
+
+        def download(self, blob_file_path, snapshot_dir):
+            if not os.path.exists(self.url):
+                raise FileNotFoundError(f"No such file: '{self.url}'")
+            # copy from the local location to the blob directory so the model store "owns" the data
+            shutil.copy(self.url, blob_file_path)
+            return os.path.relpath(blob_file_path, start=snapshot_dir)
+
+    def import_all(self):
+        if not os.path.exists(self._old_model_path):
+            return
+
+        print("Starting importing AI models to new store...")
+        for root, _, files in os.walk(self._old_model_path):
+            if not files:
+                continue
+
+            try:
+                # reconstruct the cli model input
+                model = root.replace(self._old_model_path, "")
+                model = model.replace(os.sep, "", 1)
+                if model.startswith("file/"):
+                    model = model.replace("/", ":///", 1)
+                else:
+                    model = model.replace("/", "://", 1)
+
+                if model in MODEL_TYPES:
+                    model = ""
+
+                for file in files:
+                    m = ModelFactory(
+                        os.path.join(model, file),
+                        args=dotdict(
+                            {
+                                "store": self.store_path,
+                                "use_model_store": True,
+                                "engine": "podman",
+                                "container": True,
+                            }
+                        ),
+                    ).create()
+                    _, model_tag, _ = m.extract_model_identifiers()
+                    _, _, all = m.store.get_cached_files(model_tag)
+                    if all:
+                        print(f"Already imported: {root}/{file}")
+                        continue
+
+                    snapshot_hash = generate_sha256(file)
+                    old_model_path = os.path.join(root, file)
+
+                    files: list[SnapshotFile] = []
+                    files.append(
+                        ModelStoreImport.LocalModelFile(
+                            url=old_model_path,
+                            header={},
+                            hash=snapshot_hash,
+                            name=file,
+                            required=True,
+                        )
+                    )
+
+                    m.store.new_snapshot(model_tag, snapshot_hash, files)
+                    print(f"Imported {old_model_path} -> {m.store.get_snapshot_file_path(snapshot_hash, file)}")
+
+            except Exception as ex:
+                print(f"Failed to import {root}: {ex}")
+
+        if os.path.exists(self._old_model_path):
+            try:
+                shutil.rmtree(self._old_model_path)
+            except Exception as ex:
+                print(f"Failed to remove old model directory: {ex}")
+        if os.path.exists(self._old_repo_path):
+            try:
+                shutil.rmtree(self._old_repo_path)
+            except Exception as ex:
+                print(f"Failed to remove old blob directory: {ex}")
