Timm support #262
Merged
Changes from 33 commits
38 commits
Commit | Message | Author |
---|---|---|
a6c64d8 | timm hparams | A-Jacobson |
83057ec | timm resnet50 yaml | A-Jacobson |
707c4e5 | fix hparams interface | A-Jacobson |
a2b6dd6 | optional typing | A-Jacobson |
640909d | model -> model_name | A-Jacobson |
af4d16a | model -> model_name | A-Jacobson |
2409e31 | timm model wrapper | A-Jacobson |
525c211 | add timm to __init__ | A-Jacobson |
71a9330 | train -> total batch size | A-Jacobson |
124a973 | timm model wrapper | A-Jacobson |
d63d608 | back to train batchsize | A-Jacobson |
1985f60 | Merge branch 'dev' into timm-support | A-Jacobson |
ce5bc2a | Update model.py | A-Jacobson |
bcd7cc9 | Update timm_hparams.py | A-Jacobson |
a8d9bbd | Update setup.py | A-Jacobson |
dd8e6d1 | Merge branch 'dev' into timm-support | hanlint |
40726a5 | sort imports | A-Jacobson |
f1564f0 | run yapf | A-Jacobson |
6a13fc8 | Update composer/models/timm/model.py | A-Jacobson |
c178612 | update docstring | A-Jacobson |
549cad9 | Merge branch 'dev' into timm-support | A-Jacobson |
5f6b890 | pull dev | A-Jacobson |
66733cf | fix merge conflict | A-Jacobson |
6b396a8 | add license | A-Jacobson |
fe6fcac | timm registry test | A-Jacobson |
3e2a746 | skip timm tests if timm isn't installed | A-Jacobson |
932bea9 | lint ignore lines | A-Jacobson |
f593eca | lint | A-Jacobson |
7e6258f | lint test | A-Jacobson |
9e6b72e | Merge branch 'dev' into timm-support | A-Jacobson |
77a9924 | don't skip non-timm tests | A-Jacobson |
db6c4eb | Merge branch 'dev' into timm-support | A-Jacobson |
af182eb | fix imports | A-Jacobson |
a8ce15c | Merge branch 'dev' into timm-support | hanlint |
7d75753 | Merge branch 'dev' into timm-support | hanlint |
a18c756 | fix importorskip | hanlint |
f4c4758 | cleanup | hanlint |
26e5081 | Merge branch 'dev' into timm-support | hanlint |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Copyright 2021 MosaicML. All Rights Reserved. | ||
from composer.models.timm.model import Timm as Timm | ||
from composer.models.timm.timm_hparams import TimmHparams as TimmHparams |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright 2021 MosaicML. All Rights Reserved. | ||
from typing import Optional | ||
| ||
from composer.models.base import MosaicClassifier | ||
| ||
| ||
class Timm(MosaicClassifier): | ||
    """A wrapper around timm.create_model() used to create Mosaic classifiers from timm models. | ||
    Args: | ||
        model_name (str): timm model name, e.g. 'resnet50'. A list of models can be found at https://github.com/rwightman/pytorch-image-models | ||
        pretrained (bool): Use ImageNet-pretrained weights. default: False | ||
        num_classes (int): The number of classes. Needed for classification tasks. default: 1000 | ||
        drop_rate (float): Dropout rate. default: 0.0 | ||
        drop_path_rate (float): Drop path rate (model default if None). default: None | ||
        drop_block_rate (float): Drop block rate (model default if None). default: None | ||
        global_pool (str): Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None. default: None | ||
        bn_momentum (float): BatchNorm momentum override (model default if None). default: None | ||
        bn_eps (float): BatchNorm epsilon override (model default if None). default: None | ||
    """ | ||
| ||
    def __init__( | ||
        self, | ||
        model_name: str, | ||
        pretrained: bool = False, | ||
        num_classes: int = 1000, | ||
        drop_rate: float = 0.0, | ||
        drop_path_rate: Optional[float] = None, | ||
        drop_block_rate: Optional[float] = None, | ||
        global_pool: Optional[str] = None, | ||
        bn_momentum: Optional[float] = None, | ||
        bn_eps: Optional[float] = None, | ||
    ) -> None: | ||
        import timm  # local import: timm is an optional dependency | ||
| ||
        model = timm.create_model( | ||
            model_name=model_name, | ||
            pretrained=pretrained, | ||
            num_classes=num_classes, | ||
            drop_rate=drop_rate, | ||
            drop_path_rate=drop_path_rate, | ||
            drop_block_rate=drop_block_rate, | ||
            global_pool=global_pool, | ||
            bn_momentum=bn_momentum, | ||
            bn_eps=bn_eps, | ||
        ) | ||
        super().__init__(module=model) |
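The wrapper above imports timm inside `__init__` rather than at module level, so the package can be imported even when timm is not installed. A minimal standard-library sketch of that deferred-import pattern follows. The names `import_optional` and `LazyWrapper` and the `pip install` hint are hypothetical and are not part of composer or timm; `json` stands in for the optional backend only so the example runs anywhere.

```python
import importlib
from types import ModuleType


def import_optional(module_name: str, install_hint: str) -> ModuleType:
    """Import an optional dependency or raise an actionable ImportError.

    Hypothetical helper for illustration; not a composer or timm API.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        raise ImportError(
            f"'{module_name}' is required for this model; "
            f"try: pip install {install_hint}"
        ) from err


class LazyWrapper:
    """Stand-in for a model wrapper that defers its heavy import."""

    def __init__(self, backend: str = "json") -> None:
        # The import happens at construction time, not at module import time,
        # so importing this module never fails when the backend is missing.
        self.backend = import_optional(backend, install_hint=backend)
```

Constructing `LazyWrapper()` succeeds because `json` is always available, while constructing it with a missing package name raises an `ImportError` that carries the install hint.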
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2021 MosaicML. All Rights Reserved. | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
| ||
import yahp as hp | ||
| ||
from composer.models.model_hparams import ModelHparams | ||
from composer.models.timm.model import Timm | ||
| ||
| ||
@dataclass | ||
class TimmHparams(ModelHparams): | ||
| ||
    model_name: str = hp.optional( | ||
        "timm model name, e.g. 'resnet50'. A list of models can be found at https://github.com/rwightman/pytorch-image-models", | ||
        default=None, | ||
    ) | ||
    pretrained: bool = hp.optional("Use ImageNet-pretrained weights", default=False) | ||
    num_classes: int = hp.optional("The number of classes. Needed for classification tasks", default=1000) | ||
    drop_rate: float = hp.optional("Dropout rate", default=0.0) | ||
    drop_path_rate: Optional[float] = hp.optional("Drop path rate (model default if None)", default=None) | ||
    drop_block_rate: Optional[float] = hp.optional("Drop block rate (model default if None)", default=None) | ||
    global_pool: Optional[str] = hp.optional( | ||
        "Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.", default=None) | ||
    bn_momentum: Optional[float] = hp.optional("BatchNorm momentum override (model default if None)", default=None) | ||
    bn_eps: Optional[float] = hp.optional("BatchNorm epsilon override (model default if None)", default=None) | ||
| ||
    def validate(self): | ||
        # model_name defaults to None only so that a missing value can be reported with the list of valid names | ||
        if self.model_name is None: | ||
            import timm | ||
            raise ValueError(f"model_name must be one of {timm.models.list_models()}") | ||
| ||
    def initialize_object(self): | ||
        return Timm(model_name=self.model_name, | ||
                    pretrained=self.pretrained, | ||
                    num_classes=self.num_classes, | ||
                    drop_rate=self.drop_rate, | ||
                    drop_path_rate=self.drop_path_rate, | ||
                    drop_block_rate=self.drop_block_rate, | ||
                    global_pool=self.global_pool, | ||
                    bn_momentum=self.bn_momentum, | ||
                    bn_eps=self.bn_eps) |
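The hparams above use `None` to mean "keep the model default" and reject a missing `model_name` before the model is built. A minimal sketch of that convention, using only the standard library, might look like the following. `drop_unset`, `build_model_kwargs`, and the `known_models` tuple are hypothetical names for illustration; this diff does not show whether unset values are filtered in the wrapper or inside `timm.create_model`, so the explicit filtering here is an assumption.

```python
from typing import Any, Dict, Optional, Tuple


def drop_unset(**kwargs: Any) -> Dict[str, Any]:
    """Return only the keyword arguments whose value is not None."""
    return {key: value for key, value in kwargs.items() if value is not None}


def build_model_kwargs(
    model_name: Optional[str],
    known_models: Tuple[str, ...] = ("resnet18", "resnet50"),  # hypothetical registry
    **overrides: Any,
) -> Dict[str, Any]:
    """Validate the model name, then keep only overrides that were actually set."""
    if model_name is None:
        raise ValueError(f"model_name must be one of {list(known_models)}")
    # None-valued overrides are dropped so the model's own defaults stay in effect.
    return {"model_name": model_name, **drop_unset(**overrides)}
```

Note that `drop_rate=0.0` is a real value and is kept; only `None` is treated as "unset".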
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
train_dataset: | ||
  imagenet: | ||
    resize_size: -1 | ||
    crop_size: 224 | ||
    is_train: true | ||
    datadir: /datasets/ImageNet | ||
    shuffle: true | ||
    drop_last: true | ||
val_dataset: | ||
  imagenet: | ||
    resize_size: 256 | ||
    crop_size: 224 | ||
    is_train: false | ||
    datadir: /datasets/ImageNet | ||
    shuffle: false | ||
    drop_last: false | ||
optimizer: | ||
  decoupled_sgdw: | ||
    lr: 2.048 | ||
    momentum: 0.875 | ||
    weight_decay: 5.0e-4 | ||
    dampening: 0 | ||
    nesterov: false | ||
schedulers: | ||
  - warmup: | ||
      warmup_iters: "8ep" | ||
      warmup_method: linear | ||
      warmup_factor: 0 | ||
      verbose: false | ||
      interval: step | ||
  - cosine_decay: | ||
      T_max: "82ep" | ||
      eta_min: 0 | ||
      verbose: false | ||
      interval: step | ||
model: | ||
  timm: | ||
    model_name: 'resnet50' | ||
    num_classes: 1000 | ||
loggers: | ||
  - tqdm: {} | ||
max_duration: 90ep | ||
train_batch_size: 2048 | ||
eval_batch_size: 2048 | ||
seed: 17 | ||
device: | ||
  gpu: {} | ||
dataloader: | ||
  pin_memory: true | ||
  timeout: 0 | ||
  prefetch_factor: 2 | ||
  persistent_workers: true | ||
  num_workers: 8 | ||
validate_every_n_epochs: 1 | ||
grad_accum: 1 | ||
precision: amp |
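In the config above, the 8-epoch linear warmup and the 82-epoch cosine decay add up to the 90-epoch `max_duration`. The sketch below computes an idealized learning rate for that schedule from the YAML values. It assumes the cosine phase starts exactly where warmup ends and evaluates the rate at a fractional epoch; the trainer's own scheduler composition and per-step updates (`interval: step`) may differ in detail. `lr_at` is a hypothetical helper, not a composer function.

```python
import math

# Values copied from the YAML config above.
BASE_LR = 2.048
WARMUP_EPOCHS = 8    # warmup_iters: "8ep", warmup_factor: 0
COSINE_EPOCHS = 82   # T_max: "82ep"
ETA_MIN = 0.0        # eta_min: 0
MAX_EPOCHS = 90      # max_duration: 90ep


def lr_at(epoch: float) -> float:
    """Idealized learning rate at a (possibly fractional) epoch."""
    if epoch < WARMUP_EPOCHS:
        # Linear warmup from 0 up to the base learning rate.
        return BASE_LR * epoch / WARMUP_EPOCHS
    # Cosine decay from the base learning rate down to ETA_MIN.
    progress = min((epoch - WARMUP_EPOCHS) / COSINE_EPOCHS, 1.0)
    return ETA_MIN + 0.5 * (BASE_LR - ETA_MIN) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for epoch in (0, 4, 8, 49, 90):
        print(f"epoch {epoch:>2}: lr = {lr_at(epoch):.4f}")
```

Under these assumptions the rate starts at zero, peaks at the base value when warmup ends, and returns to `eta_min` at epoch 90.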