Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions composer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from composer.models.resnet9_cifar10 import CIFARResNet9Hparams as CIFARResNet9Hparams
from composer.models.resnet18 import ResNet18 as ResNet18
from composer.models.resnet18 import ResNet18Hparams as ResNet18Hparams
from composer.models.resnet20_cifar10 import CIFAR10_ResNet20 as CIFAR10_ResNet20
from composer.models.resnet20_cifar10 import CIFARResNet20Hparams as CIFARResNet20Hparams
from composer.models.resnet50 import ResNet50 as ResNet50
from composer.models.resnet50 import ResNet50Hparams as ResNet50Hparams
from composer.models.resnet56_cifar10 import CIFAR10_ResNet56 as CIFAR10_ResNet56
Expand Down
12 changes: 12 additions & 0 deletions composer/models/resnet20_cifar10/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from composer.models.resnet20_cifar10.model import CIFAR10_ResNet20 as CIFAR10_ResNet20
from composer.models.resnet20_cifar10.resnet20_cifar10_hparams import CIFARResNet20Hparams as CIFARResNet20Hparams

_task = 'Image Classification'
_dataset = 'CIFAR10'
_name = 'ResNet20'
_quality = 'tbd'
_metric = 'Top-1 Accuracy'
_ttt = 'tbd'
_hparams = 'resnet20_cifar10.yaml'
36 changes: 36 additions & 0 deletions composer/models/resnet20_cifar10/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from typing import List, Optional

from composer.models.base import MosaicClassifier
from composer.models.model_hparams import Initializer
from composer.models.resnets import CIFAR_ResNet


class CIFAR10_ResNet20(MosaicClassifier):
"""A ResNet-20 model extending :class:`MosaicClassifier`.

See this `paper <https://arxiv.org/abs/1512.03385>`_ for details
on the residual network architecture.

Args:
num_classes (int): The number of classes for the model. Default = 10.
initializers (List[Initializer], optional): Initializers
for the model. ``None`` for no initialization.
(default: ``None``)
"""

def __init__(
self,
num_classes: int = 10,
initializers: Optional[List[Initializer]] = None,
) -> None:
if initializers is None:
initializers = []

model = CIFAR_ResNet.get_model_from_name(
"cifar_resnet_20",
initializers,
num_classes,
)
super().__init__(module=model)
13 changes: 13 additions & 0 deletions composer/models/resnet20_cifar10/resnet20_cifar10_hparams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from dataclasses import asdict, dataclass

from composer.models.model_hparams import ModelHparams


@dataclass
class CIFARResNet20Hparams(ModelHparams):

def initialize_object(self):
from composer.models import CIFAR10_ResNet20
return CIFAR10_ResNet20(**asdict(self))
2 changes: 1 addition & 1 deletion composer/models/resnet56_cifar10/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
_name = 'ResNet56'
_quality = '93.1'
_metric = 'Top-1 Accuracy'
_ttt = '15m'
_ttt = 'tbd'
_hparams = 'resnet56_cifar10.yaml'
2 changes: 1 addition & 1 deletion composer/models/resnet56_cifar10/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class CIFAR10_ResNet56(MosaicClassifier):
on the residual network architecture.

Args:
num_classes (int): The number of classes for the model.
num_classes (int): The number of classes for the model. Default = 10.
initializers (List[Initializer], optional): Initializers
for the model. ``None`` for no initialization.
(default: ``None``)
Expand Down
3 changes: 2 additions & 1 deletion composer/models/resnets.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def forward(self, x):
@staticmethod
def is_valid_model_name(model_name):
return (model_name.startswith('cifar_resnet_') and 4 >= len(model_name.split('_')) >= 3 and
model_name.split('_')[2].isdigit() and int(model_name.split('_')[2]) in [56])
model_name.split('_')[2].isdigit() and int(model_name.split('_')[2]) in [20,56])

@staticmethod
def get_model_from_name(model_name, initializers: List[Initializer], outputs=10):
Expand All @@ -327,6 +327,7 @@ def get_model_from_name(model_name, initializers: List[Initializer], outputs=10)

model_arch = {
56: [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)],
20: [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)],
}

return CIFAR_ResNet(model_arch[depth], initializers, outputs)
2 changes: 2 additions & 0 deletions composer/trainer/trainer_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from composer.models import (BERTForClassificationHparams, BERTHparams, CIFARResNet9Hparams, CIFARResNetHparams,
DeepLabV3Hparams, EfficientNetB0Hparams, GPT2Hparams, MnistClassifierHparams, ModelHparams,
ResNet18Hparams, ResNet50Hparams, ResNet101Hparams, UnetHparams)
from composer.models.resnet20_cifar10.resnet20_cifar10_hparams import CIFARResNet20Hparams
from composer.optim import (AdamHparams, AdamWHparams, DecoupledAdamWHparams, DecoupledSGDWHparams, OptimizerHparams,
RAdamHparams, RMSPropHparams, SchedulerHparams, SGDHparams, scheduler)
from composer.profiler import ProfilerHparams
Expand Down Expand Up @@ -63,6 +64,7 @@
"deeplabv3": DeepLabV3Hparams,
"efficientnetb0": EfficientNetB0Hparams,
"resnet56_cifar10": CIFARResNetHparams,
"resnet20_cifar10": CIFARResNet20Hparams,
"resnet9_cifar10": CIFARResNet9Hparams,
"resnet101": ResNet101Hparams,
"resnet50": ResNet50Hparams,
Expand Down
54 changes: 54 additions & 0 deletions composer/yamls/models/resnet20_cifar10.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
train_dataset:
cifar10:
datadir: /datasets/CIFAR10
is_train: true
download: false
shuffle: true
drop_last: true
val_dataset:
cifar10:
datadir: /datasets/CIFAR10
is_train: false
download: false
shuffle: false
drop_last: false
optimizer:
decoupled_sgdw:
lr: 1.2
momentum: 0.9
weight_decay: 2.0e-3
schedulers:
- warmup:
warmup_iters: "5ep"
warmup_method: linear
warmup_factor: 0
verbose: false
interval: step
- multistep:
milestones:
- "80ep"
- "120ep"
gamma: 0.1
interval: epoch
model:
resnet20_cifar10:
initializers:
- kaiming_normal
- bn_uniform
loggers:
- tqdm: {}
max_duration: 160ep
train_batch_size: 1024
eval_batch_size: 1000
seed: 17
validate_every_n_epochs: 1
grad_accum: 1
device:
gpu: {}
dataloader:
pin_memory: true
timeout: 0
prefetch_factor: 2
persistent_workers: true
num_workers: 8
precision: amp
1 change: 0 additions & 1 deletion composer/yamls/models/resnet56_cifar10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ model:
initializers:
- kaiming_normal
- bn_uniform
num_classes: 10
loggers:
- tqdm: {}
max_duration: 160ep
Expand Down
1 change: 0 additions & 1 deletion composer/yamls/models/resnet56_cifar10_synthetic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ model:
initializers:
- kaiming_normal
- bn_uniform
num_classes: 10
loggers:
- tqdm: {}
max_duration: 160ep
Expand Down