Merged

Commits (38)

a6c64d8  timm hparams (A-Jacobson, Jan 20, 2022)
83057ec  timm resnet50 yaml (A-Jacobson, Jan 20, 2022)
707c4e5  fix hparams interface (A-Jacobson, Jan 20, 2022)
a2b6dd6  optional typing (A-Jacobson, Jan 20, 2022)
640909d  model -> model_name (A-Jacobson, Jan 21, 2022)
af4d16a  model -> model_name (A-Jacobson, Jan 21, 2022)
2409e31  timm model wrapper (A-Jacobson, Jan 21, 2022)
525c211  add timm to __init__ (A-Jacobson, Jan 21, 2022)
71a9330  train -> total batch size (A-Jacobson, Jan 21, 2022)
124a973  timm model wrapper (A-Jacobson, Jan 21, 2022)
d63d608  back to train batchsize (A-Jacobson, Jan 21, 2022)
1985f60  Merge branch 'dev' into timm-support (A-Jacobson, Jan 21, 2022)
ce5bc2a  Update model.py (A-Jacobson, Jan 21, 2022)
bcd7cc9  Update timm_hparams.py (A-Jacobson, Jan 21, 2022)
a8d9bbd  Update setup.py (A-Jacobson, Jan 21, 2022)
dd8e6d1  Merge branch 'dev' into timm-support (hanlint, Jan 24, 2022)
40726a5  sort imports (A-Jacobson, Jan 25, 2022)
f1564f0  run yapf (A-Jacobson, Jan 28, 2022)
6a13fc8  Update composer/models/timm/model.py (A-Jacobson, Jan 28, 2022)
c178612  update docstring (A-Jacobson, Jan 28, 2022)
549cad9  Merge branch 'dev' into timm-support (A-Jacobson, Jan 28, 2022)
5f6b890  pull dev (A-Jacobson, Jan 28, 2022)
66733cf  fix merge conflict (A-Jacobson, Jan 28, 2022)
6b396a8  add license (A-Jacobson, Jan 28, 2022)
fe6fcac  timm registry test (A-Jacobson, Jan 28, 2022)
3e2a746  skip timm tests if timm isn't installed (A-Jacobson, Jan 29, 2022)
932bea9  lint ignore lines (A-Jacobson, Jan 29, 2022)
f593eca  lint (A-Jacobson, Jan 29, 2022)
7e6258f  lint test (A-Jacobson, Jan 29, 2022)
9e6b72e  Merge branch 'dev' into timm-support (A-Jacobson, Jan 29, 2022)
77a9924  don't skip non-timm tests (A-Jacobson, Jan 31, 2022)
db6c4eb  Merge branch 'dev' into timm-support (A-Jacobson, Jan 31, 2022)
af182eb  fix imports (A-Jacobson, Jan 31, 2022)
a8ce15c  Merge branch 'dev' into timm-support (hanlint, Feb 1, 2022)
7d75753  Merge branch 'dev' into timm-support (hanlint, Feb 1, 2022)
a18c756  fix importorskip (hanlint, Feb 1, 2022)
f4c4758  cleanup (hanlint, Feb 1, 2022)
26e5081  Merge branch 'dev' into timm-support (hanlint, Feb 1, 2022)
2 changes: 2 additions & 0 deletions composer/models/__init__.py
@@ -27,6 +27,8 @@
 from composer.models.resnet56_cifar10 import CIFARResNetHparams as CIFARResNetHparams
 from composer.models.resnet101 import ResNet101 as ResNet101
 from composer.models.resnet101 import ResNet101Hparams as ResNet101Hparams
+from composer.models.timm import Timm as Timm
+from composer.models.timm import TimmHparams as TimmHparams
 from composer.models.transformer_shared import MosaicTransformer as MosaicTransformer
 from composer.models.unet import UNet as UNet
 from composer.models.unet import UnetHparams as UnetHparams
3 changes: 3 additions & 0 deletions composer/models/timm/__init__.py
@@ -0,0 +1,3 @@
# Copyright 2021 MosaicML. All Rights Reserved.
from composer.models.timm.model import Timm as Timm
from composer.models.timm.timm_hparams import TimmHparams as TimmHparams
46 changes: 46 additions & 0 deletions composer/models/timm/model.py
@@ -0,0 +1,46 @@
# Copyright 2021 MosaicML. All Rights Reserved.
from typing import Optional

from composer.models.base import MosaicClassifier


class Timm(MosaicClassifier):
    """A wrapper around ``timm.create_model()`` used to create MosaicClassifiers from timm models.

    Args:
        model_name (str): timm model name, e.g. ``'resnet50'``. A list of models can be found at
            https://github.com/rwightman/pytorch-image-models.
        pretrained (bool): Whether to load ImageNet-pretrained weights. Default: ``False``.
        num_classes (int): The number of classes. Needed for classification tasks. Default: ``1000``.
        drop_rate (float): Dropout rate. Default: ``0.0``.
        drop_path_rate (Optional[float]): Drop path rate (model default if ``None``). Default: ``None``.
        drop_block_rate (Optional[float]): Drop block rate (model default if ``None``). Default: ``None``.
        global_pool (Optional[str]): Global pool type, one of ``fast``, ``avg``, ``max``, ``avgmax``,
            ``avgmaxc`` (model default if ``None``). Default: ``None``.
        bn_momentum (Optional[float]): BatchNorm momentum override (model default if ``None``). Default: ``None``.
        bn_eps (Optional[float]): BatchNorm epsilon override (model default if ``None``). Default: ``None``.
    """

    def __init__(
        self,
        model_name: str,
        pretrained: bool = False,
        num_classes: int = 1000,
        drop_rate: float = 0.0,
        drop_path_rate: Optional[float] = None,
        drop_block_rate: Optional[float] = None,
        global_pool: Optional[str] = None,
        bn_momentum: Optional[float] = None,
        bn_eps: Optional[float] = None,
    ) -> None:
        # timm is an optional dependency, so import it here rather than at module level
        import timm

        model = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            drop_block_rate=drop_block_rate,
            global_pool=global_pool,
            bn_momentum=bn_momentum,
            bn_eps=bn_eps,
        )
        super().__init__(module=model)
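
Usage note: the wrapper can be constructed directly in user code. A minimal sketch, assuming timm is installed (see the setup.py extra below) and using 'resnet18' purely as an illustrative model name:

from composer.models import Timm

# builds a timm backbone and wraps it as a MosaicClassifier
model = Timm(model_name="resnet18", pretrained=False, num_classes=10)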
42 changes: 42 additions & 0 deletions composer/models/timm/timm_hparams.py
@@ -0,0 +1,42 @@
# Copyright 2021 MosaicML. All Rights Reserved.
from dataclasses import dataclass
from typing import Optional

import yahp as hp

from composer.models.model_hparams import ModelHparams
from composer.models.timm.model import Timm


@dataclass
class TimmHparams(ModelHparams):

    model_name: str = hp.optional(
        "timm model name, e.g. 'resnet50'. A list of models can be found at https://github.com/rwightman/pytorch-image-models",
        default=None,
    )
    pretrained: bool = hp.optional("imagenet pretrained", default=False)
    num_classes: int = hp.optional("The number of classes. Needed for classification tasks", default=1000)
    drop_rate: float = hp.optional("dropout rate", default=0.0)
    drop_path_rate: Optional[float] = hp.optional("drop path rate (model default if None)", default=None)
    drop_block_rate: Optional[float] = hp.optional("drop block rate (model default if None)", default=None)
    global_pool: Optional[str] = hp.optional(
        "Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.", default=None)
    bn_momentum: Optional[float] = hp.optional("BatchNorm momentum override (model default if None)", default=None)
    bn_eps: Optional[float] = hp.optional("BatchNorm epsilon override (model default if None)", default=None)

    def validate(self):
        if self.model_name is None:
            # imported lazily so the error message can list every valid timm model name
            import timm
            raise ValueError(f"model_name must be one of {timm.models.list_models()}")

    def initialize_object(self):
        return Timm(model_name=self.model_name,
                    pretrained=self.pretrained,
                    num_classes=self.num_classes,
                    drop_rate=self.drop_rate,
                    drop_path_rate=self.drop_path_rate,
                    drop_block_rate=self.drop_block_rate,
                    global_pool=self.global_pool,
                    bn_momentum=self.bn_momentum,
                    bn_eps=self.bn_eps)
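
Usage note: the hparams class mirrors the wrapper's signature field for field. A minimal sketch of the YAML-free path, again assuming timm is installed:

from composer.models import TimmHparams

hparams = TimmHparams(model_name="resnet50")
hparams.validate()  # raises ValueError when model_name is None
model = hparams.initialize_object()  # forwards every field to Timm(...)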
9 changes: 5 additions & 4 deletions composer/trainer/trainer_hparams.py
@@ -21,10 +21,10 @@
 from composer.datasets import DataloaderHparams
 from composer.loggers import (BaseLoggerBackendHparams, FileLoggerBackendHparams, MosaicMLLoggerBackendHparams,
                               TQDMLoggerBackendHparams, WandBLoggerBackendHparams)
-from composer.models import (BERTForClassificationHparams, BERTHparams, CIFARResNet9Hparams, CIFARResNetHparams,
-                             DeepLabV3Hparams, EfficientNetB0Hparams, GPT2Hparams, MnistClassifierHparams, ModelHparams,
-                             ResNet18Hparams, ResNet50Hparams, ResNet101Hparams, UnetHparams)
-from composer.models.resnet20_cifar10.resnet20_cifar10_hparams import CIFARResNet20Hparams
+from composer.models import (BERTForClassificationHparams, BERTHparams, CIFARResNet9Hparams, CIFARResNet20Hparams,
+                             CIFARResNetHparams, DeepLabV3Hparams, EfficientNetB0Hparams, GPT2Hparams,
+                             MnistClassifierHparams, ModelHparams, ResNet18Hparams, ResNet50Hparams, ResNet101Hparams,
+                             TimmHparams, UnetHparams)
 from composer.optim import (AdamHparams, AdamWHparams, DecoupledAdamWHparams, DecoupledSGDWHparams, OptimizerHparams,
                             RAdamHparams, RMSPropHparams, SchedulerHparams, SGDHparams, scheduler)
 from composer.profiler import ProfilerHparams
@@ -73,6 +73,7 @@
     "gpt2": GPT2Hparams,
     "bert": BERTHparams,
     "bert_classification": BERTForClassificationHparams,
+    "timm": TimmHparams
 }

 dataset_registry = {
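The "timm" registry key is what a model: section in a YAML resolves against, so the config below can select timm models by name. A short sketch of the lookup, mirroring tests/test_model_registry.py further down:

from composer.models import TimmHparams
from composer.trainer.trainer_hparams import model_registry

assert model_registry["timm"] is TimmHparams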
56 changes: 56 additions & 0 deletions composer/yamls/models/timm_resnet50_imagenet.yaml
@@ -0,0 +1,56 @@
train_dataset:
  imagenet:
    resize_size: -1
    crop_size: 224
    is_train: true
    datadir: /datasets/ImageNet
    shuffle: true
    drop_last: true
val_dataset:
  imagenet:
    resize_size: 256
    crop_size: 224
    is_train: false
    datadir: /datasets/ImageNet
    shuffle: false
    drop_last: false
optimizer:
  decoupled_sgdw:
    lr: 2.048
    momentum: 0.875
    weight_decay: 5.0e-4
    dampening: 0
    nesterov: false
schedulers:
  - warmup:
      warmup_iters: "8ep"
      warmup_method: linear
      warmup_factor: 0
      verbose: false
      interval: step
  - cosine_decay:
      T_max: "82ep"
      eta_min: 0
      verbose: false
      interval: step
model:
  timm:
    model_name: 'resnet50'
    num_classes: 1000
loggers:
  - tqdm: {}
max_duration: 90ep
train_batch_size: 2048
eval_batch_size: 2048
seed: 17
device:
  gpu: {}
dataloader:
  pin_memory: true
  timeout: 0
  prefetch_factor: 2
  persistent_workers: true
  num_workers: 8
validate_every_n_epochs: 1
grad_accum: 1
precision: amp
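
This YAML can be parsed and validated programmatically, the same way tests/test_hparams.py exercises it below:

from composer.trainer.trainer_hparams import TrainerHparams

# cli_args=False keeps the loader from consuming command-line arguments
hparams = TrainerHparams.create("composer/yamls/models/timm_resnet50_imagenet.yaml", cli_args=False)
assert isinstance(hparams, TrainerHparams)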
2 changes: 2 additions & 0 deletions setup.py
@@ -92,6 +92,8 @@ def package_files(directory: str):
     'datasets>=1.14.0',
 ]

+extra_deps['vision'] = ['timm>=0.5.4']
+
 extra_deps['unet'] = [
     'monai>=0.7.0',
     'scikit-learn>=1.0.1',
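With this extra defined, timm is an opt-in install rather than a hard dependency; assuming the package's published name, something like pip install 'mosaicml[vision]' pulls it in. This is also why the test changes below skip timm cases when the import is unavailable.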
11 changes: 10 additions & 1 deletion tests/test_hparams.py
@@ -25,7 +25,7 @@ def walk_model_yamls():

 def _configure_dataset_for_synthetic(dataset_hparams: DatasetHparams) -> None:
     if not isinstance(dataset_hparams, SyntheticHparamsMixin):
-        pytest.xfail(f"{dataset_hparams.__class__.__name__} does not support synthetic data or num_total_batchjes")
+        pytest.xfail(f"{dataset_hparams.__class__.__name__} does not support synthetic data or num_total_batches")

     assert isinstance(dataset_hparams, SyntheticHparamsMixin)

@@ -36,10 +36,19 @@ def _configure_dataset_for_synthetic(dataset_hparams: DatasetHparams) -> None:
 class TestHparamsCreate:

     def test_hparams_create(self, hparams_file: str):
+        if hparams_file in ["timm_resnet50_imagenet.yaml"]:
+            pytest.importorskip("timm")
+        if hparams_file in ["unet.yaml"]:
+            pytest.importorskip("monai")
+
         hparams = TrainerHparams.create(hparams_file, cli_args=False)
         assert isinstance(hparams, TrainerHparams)

     def test_trainer_initialize(self, hparams_file: str):
+        if hparams_file in ["timm_resnet50_imagenet.yaml"]:
+            pytest.importorskip("timm")
+        if hparams_file in ["unet.yaml"]:
+            pytest.importorskip("monai")
         hparams = TrainerHparams.create(hparams_file, cli_args=False)
         hparams.dataloader.num_workers = 0
         hparams.dataloader.persistent_workers = False
10 changes: 5 additions & 5 deletions tests/test_load.py
@@ -25,11 +25,7 @@ def get_model_algs(model_name: str) -> List[str]:
     if is_image_model:
         algs.remove("alibi")
     if "alibi" in algs:
-        try:
-            import transformers
-            del transformers
-        except ImportError:
-            pytest.skip("Unable to import transformers; skipping alibi")
+        pytest.importorskip("transformers")
     if model_name in ("unet", "gpt2_52m", "gpt2_83m", 'gpt2_125m'):
         algs.remove("mixup")
         algs.remove("cutmix")
@@ -39,6 +35,10 @@
 @pytest.mark.parametrize('model_name', model_names)
 @pytest.mark.timeout(15)
 def test_load(model_name: str):
+    if model_name in ['timm']:
+        pytest.importorskip("timm")
+    if model_name in ['unet']:
+        pytest.importorskip("monai")
     trainer_hparams = trainer.load(model_name)
     trainer_hparams.precision = Precision.FP32
     trainer_hparams.algorithms = algorithms.load_multiple(*get_model_algs(model_name))
19 changes: 9 additions & 10 deletions tests/test_model_registry.py
@@ -2,12 +2,17 @@

 import pytest

-from composer.models import BaseMosaicModel, ModelHparams
+from composer.models import ModelHparams
 from composer.trainer.trainer_hparams import model_registry


 @pytest.mark.parametrize("model_name", model_registry.keys())
 def test_model_registry(model_name, request):
+    if model_name in ['timm']:
+        pytest.importorskip("timm")
+    if model_name in ['unet']:
+        pytest.importorskip("monai")
+
     # TODO (Moin + Ravi): create dummy versions of these models to pass unit tests
     if model_name in ['gpt2', 'bert', 'bert_classification']:  # do not pull from HF model hub
         request.applymarker(pytest.mark.xfail())
@@ -31,13 +36,7 @@ def test_model_registry(model_name, request):
     if model_name == "deeplabv3":
         model_hparams.is_backbone_pretrained = False

-    assert isinstance(model_hparams, ModelHparams)
+    if model_name == "timm":
+        model_hparams.model_name = "resnet18"

-    try:
-        # create the model object using the hparams
-        model = model_hparams.initialize_object()
-        assert isinstance(model, BaseMosaicModel)
-    except ModuleNotFoundError as e:
-        if model_name == "unet" and e.name == 'monai':
-            pytest.skip("Unet not installed -- skipping")
-        raise e
+    assert isinstance(model_hparams, ModelHparams)