14 changes: 12 additions & 2 deletions timm/models/_builder.py
@@ -11,8 +11,8 @@
from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
    load_custom_from_hf
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf, \
    load_state_dict_from_path, load_custom_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
@@ -45,6 +45,9 @@ def _resolve_pretrained_source(pretrained_cfg):
        load_from = 'hf-hub'
        assert hf_hub_id
        pretrained_loc = hf_hub_id
    elif cfg_source == 'local-dir':
        load_from = 'local-dir'
        pretrained_loc = pretrained_file
    else:
        # default source == timm or unspecified
        if pretrained_sd:
@@ -211,6 +214,13 @@ def load_pretrained(
                state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
        else:
            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
    elif load_from == 'local-dir':
        _logger.info(f'Loading pretrained weights from local directory ({pretrained_loc})')
        pretrained_path = Path(pretrained_loc)
        if pretrained_path.is_dir():
            state_dict = load_state_dict_from_path(pretrained_path)
        else:
            raise RuntimeError(f"Specified path is not a directory: {pretrained_loc}")
    else:
        model_name = pretrained_cfg.get('architecture', 'this model')
        raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
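For reference, a minimal sketch of the pretrained_cfg shape that reaches the new branch, assuming _resolve_pretrained_source returns the (load_from, pretrained_loc) pair its caller uses; the directory path is hypothetical:

pretrained_cfg = {
    'architecture': 'resnet50',
    'source': 'local-dir',
    'file': '/data/checkpoints/my_resnet50',  # directory holding config.json + weights
}
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
# expected: ('local-dir', '/data/checkpoints/my_resnet50'); load_pretrained() then
# hands the directory to load_state_dict_from_path().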
33 changes: 21 additions & 12 deletions timm/models/_factory.py
@@ -5,7 +5,7 @@

from timm.layers import set_layer_config
from ._helpers import load_checkpoint
from ._hub import load_model_config_from_hf
from ._hub import load_model_config_from_hf, load_model_config_from_path
from ._pretrained import PretrainedCfg
from ._registry import is_model, model_entrypoint, split_model_name_tag

@@ -18,13 +18,15 @@ def parse_model_name(model_name: str):
    # NOTE for backwards compat, deprecate hf_hub use
    model_name = model_name.replace('hf_hub', 'hf-hub')
    parsed = urlsplit(model_name)
    assert parsed.scheme in ('', 'timm', 'hf-hub')
    assert parsed.scheme in ('', 'hf-hub', 'local-dir')
    if parsed.scheme == 'hf-hub':
        # FIXME may use fragment as revision, currently `@` in URI path
        return parsed.scheme, parsed.path
    elif parsed.scheme == 'local-dir':
        return parsed.scheme, parsed.path
    else:
        model_name = os.path.split(parsed.path)[-1]
        return 'timm', model_name
        return None, model_name


def safe_model_name(model_name: str, remove_source: bool = True):
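Illustrative return values for the reworked parse_model_name (model identifiers are hypothetical):

parse_model_name('hf-hub:timm/resnet50.a1_in1k')        # ('hf-hub', 'timm/resnet50.a1_in1k')
parse_model_name('local-dir:/data/checkpoints/my_vit')  # ('local-dir', '/data/checkpoints/my_vit')
parse_model_name('resnet50.a1_in1k')                    # (None, 'resnet50.a1_in1k'), plain timm name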
@@ -100,20 +102,27 @@ def create_model(
    # non-supporting models don't break and default args remain in effect.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    model_source, model_name = parse_model_name(model_name)
    if model_source == 'hf-hub':
    model_source, model_id = parse_model_name(model_name)
    if model_source:
        assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
        # load model weights + pretrained_cfg from Hugging Face hub.
        pretrained_cfg, model_name, model_args = load_model_config_from_hf(
            model_name,
            cache_dir=cache_dir,
        )
        if model_source == 'hf-hub':
            # For model names specified in the form `hf-hub:path/architecture_name@revision`,
            # load model weights + pretrained_cfg from Hugging Face hub.
            pretrained_cfg, model_name, model_args = load_model_config_from_hf(
                model_id,
                cache_dir=cache_dir,
            )
        elif model_source == 'local-dir':
            pretrained_cfg, model_name, model_args = load_model_config_from_path(
                model_id,
            )
        else:
            assert False, f'Unknown model_source {model_source}'
        if model_args:
            for k, v in model_args.items():
                kwargs.setdefault(k, v)
    else:
        model_name, pretrained_tag = split_model_name_tag(model_name)
        model_name, pretrained_tag = split_model_name_tag(model_id)
        if pretrained_tag and not pretrained_cfg:
            # a valid pretrained_cfg argument takes priority over tag in model name
            pretrained_cfg = pretrained_tag
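Taken together, the factory change lets a model be created straight from a local checkpoint directory. A usage sketch with a hypothetical path; the directory is expected to contain a config.json plus a weights file such as model.safetensors:

import timm

# The 'local-dir:' scheme added above routes config loading through load_model_config_from_path
# and weight loading through load_state_dict_from_path instead of the Hugging Face Hub helpers.
model = timm.create_model('local-dir:/data/checkpoints/my_vit', pretrained=True)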
127 changes: 95 additions & 32 deletions timm/models/_hub.py
@@ -5,7 +5,7 @@
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
@@ -157,42 +157,60 @@ def download_from_hf(
    )


def _parse_model_cfg(
        cfg: Dict[str, Any],
        extra_fields: Dict[str, Any],
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
    """Split a loaded config.json dict into (pretrained_cfg, model_name, model_args)."""
    # legacy "single-dict" config -> split into architecture / num_features / pretrained_cfg
    if "pretrained_cfg" not in cfg:
        pretrained_cfg = cfg
        cfg = {
            "architecture": pretrained_cfg.pop("architecture"),
            "num_features": pretrained_cfg.pop("num_features", None),
            "pretrained_cfg": pretrained_cfg,
        }
        if "labels" in pretrained_cfg:  # deprecated name, rename to label_names
            pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")

    pretrained_cfg = cfg["pretrained_cfg"]
    pretrained_cfg.update(extra_fields)

    # top-level overrides
    if "num_classes" in cfg:
        pretrained_cfg["num_classes"] = cfg["num_classes"]
    if "label_names" in cfg:
        pretrained_cfg["label_names"] = cfg.pop("label_names")
    if "label_descriptions" in cfg:
        pretrained_cfg["label_descriptions"] = cfg.pop("label_descriptions")

    model_args = cfg.get("model_args", {})
    model_name = cfg["architecture"]
    return pretrained_cfg, model_name, model_args
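To make the split concrete, a small sketch of the two config.json shapes this helper accepts (field values invented for illustration):

modern = {
    "architecture": "resnet50",
    "num_classes": 10,
    "pretrained_cfg": {"input_size": [3, 224, 224], "interpolation": "bicubic"},
    "model_args": {"drop_path_rate": 0.1},
}
legacy = {  # flat pre-split form; wrapped into the layout above by the branch at the top
    "architecture": "resnet50",
    "num_features": 2048,
    "labels": ["cat", "dog"],  # renamed to label_names
}
for cfg in (modern, legacy):
    pretrained_cfg, model_name, model_args = _parse_model_cfg(dict(cfg), {"source": "local-dir"})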


def load_model_config_from_hf(
        model_id: str,
        cache_dir: Optional[Union[str, Path]] = None,
):
    """Original HF-Hub loader (unchanged download, shared parsing)."""
    assert has_hf_hub(True)
    cached_file = download_from_hf(model_id, 'config.json', cache_dir=cache_dir)

    hf_config = load_cfg_from_json(cached_file)
    if 'pretrained_cfg' not in hf_config:
        # old form, pull pretrained_cfg out of the base dict
        pretrained_cfg = hf_config
        hf_config = {}
        hf_config['architecture'] = pretrained_cfg.pop('architecture')
        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
        hf_config['pretrained_cfg'] = pretrained_cfg

    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
    pretrained_cfg = hf_config['pretrained_cfg']
    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
    pretrained_cfg['source'] = 'hf-hub'

    # model should be created with base config num_classes if it exists
    if 'num_classes' in hf_config:
        pretrained_cfg['num_classes'] = hf_config['num_classes']

    # label metadata in base config overrides saved pretrained_cfg on load
    if 'label_names' in hf_config:
        pretrained_cfg['label_names'] = hf_config.pop('label_names')
    if 'label_descriptions' in hf_config:
        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')

    model_args = hf_config.get('model_args', {})
    model_name = hf_config['architecture']
    return pretrained_cfg, model_name, model_args
    cfg_path = download_from_hf(model_id, "config.json", cache_dir=cache_dir)
    cfg = load_cfg_from_json(cfg_path)
    return _parse_model_cfg(cfg, {"hf_hub_id": model_id, "source": "hf-hub"})


def load_model_config_from_path(
        model_path: Union[str, Path],
):
    """Load from ``<model_path>/config.json`` on the local filesystem."""
    model_path = Path(model_path)
    cfg_file = model_path / "config.json"
    if not cfg_file.is_file():
        raise FileNotFoundError(f"Config file not found: {cfg_file}")
    cfg = load_cfg_from_json(cfg_file)
    extra_fields = {"file": str(model_path), "source": "local-dir"}
    return _parse_model_cfg(cfg, extra_fields=extra_fields)
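A matching call sketch for the local config loader (path hypothetical):

pretrained_cfg, model_name, model_args = load_model_config_from_path('/data/checkpoints/my_vit')
# pretrained_cfg now carries source='local-dir' and file=<the directory>, so the later
# weight load in _builder.py resolves through load_state_dict_from_path().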


def load_state_dict_from_hf(
@@ -236,6 +254,51 @@ def load_state_dict_from_hf(
    return state_dict


_PREFERRED_FILES = (
    "model.safetensors",
    "pytorch_model.bin",
    "pytorch_model.pth",
    "model.pth",
    "open_clip_model.safetensors",
    "open_clip_pytorch_model.safetensors",
    "open_clip_pytorch_model.bin",
    "open_clip_pytorch_model.pth",
)
_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')

def load_state_dict_from_path(
        path: Union[str, Path],
        weights_only: bool = False,
):
    path = Path(path)
    found_file = None
    for fname in _PREFERRED_FILES:
        p = path / fname
        if p.exists():
            _logger.info(f"Found preferred checkpoint: {p.name}")
            found_file = p
            break

    if not found_file:
        # fallback: first match per extension, in priority order
        for ext in _EXT_PRIORITY:
            files = sorted(path.glob(f"*{ext}"))
            if files:
                if len(files) > 1:
                    _logger.warning(
                        f"Multiple {ext} checkpoints in {path}: {[f.name for f in files]}. "
                        f"Using '{files[0].name}'."
                    )
                found_file = files[0]
                break

    if not found_file:
        raise RuntimeError(f"No suitable checkpoints found in {path}.")

    try:
        state_dict = torch.load(found_file, map_location='cpu', weights_only=weights_only)
    except TypeError:
        # older torch without the weights_only argument
        state_dict = torch.load(found_file, map_location='cpu')
    return state_dict
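A call sketch (directory hypothetical): the preferred filenames above win first, then the extension priority list, with safetensors checked before .pth/.bin:

state_dict = load_state_dict_from_path(Path('/data/checkpoints/my_vit'), weights_only=True)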


def load_custom_from_hf(
        model_id: str,
        filename: str,