|
8 | 8 | Each algorithm is keyed based on its name in the algorithm registry.
|
9 | 9 | """
|
10 | 10 |
|
11 |
| -from composer.algorithms import algorithm_registry |
| 11 | +from typing import Any, Dict, Optional, Type |
| 12 | + |
| 13 | +import pytest |
| 14 | +from torch.utils.data import Dataset |
| 15 | + |
| 16 | +import composer |
| 17 | +from composer import Algorithm |
| 18 | +from composer.algorithms import (AGC, EMA, SAM, SWA, Alibi, AugMix, BlurPool, ChannelsLast, ColOut, CutMix, CutOut, |
| 19 | + Factorize, GhostBatchNorm, LabelSmoothing, LayerFreezing, MixUp, NoOpModel, |
| 20 | + ProgressiveResizing, RandAugment, SelectiveBackprop, SeqLengthWarmup, SqueezeExcite, |
| 21 | + StochasticDepth) |
12 | 22 | from composer.models import ComposerResNet
|
| 23 | +from composer.models.base import ComposerModel |
13 | 24 | from tests import common
|
14 | 25 |
|
15 | 26 | simple_vision_settings = {
|
|
23 | 34 | 'dataset': (common.RandomImageDataset, {
|
24 | 35 | 'is_PIL': True
|
25 | 36 | }),
|
26 |
| - 'kwargs': {} |
| 37 | + 'kwargs': {}, |
27 | 38 | }
|
28 | 39 |
|
29 | 40 | simple_resnet_settings = {
|
|
34 | 45 | 'dataset': (common.RandomImageDataset, {
|
35 | 46 | 'shape': (3, 224, 224),
|
36 | 47 | }),
|
| 48 | + 'kwargs': {}, |
37 | 49 | }
|
38 | 50 |
|
39 |
| -_settings = { |
40 |
| - 'agc': simple_vision_settings, |
41 |
| - 'alibi': None, # NLP settings needed |
42 |
| - 'augmix': None, # requires PIL dataset to test |
43 |
| - 'blurpool': { |
| 51 | +_settings: Dict[Type[Algorithm], Optional[Dict[str, Any]]] = { |
| 52 | + AGC: simple_vision_settings, |
| 53 | + Alibi: None, # NLP settings needed |
| 54 | + AugMix: simple_vision_settings, |
| 55 | + BlurPool: { |
44 | 56 | 'model': common.SimpleConvModel,
|
45 | 57 | 'dataset': common.RandomImageDataset,
|
46 | 58 | 'kwargs': {
|
47 | 59 | 'min_channels': 0,
|
48 | 60 | },
|
49 | 61 | },
|
50 |
| - 'channels_last': simple_vision_settings, |
51 |
| - 'colout': simple_vision_settings, |
52 |
| - 'cutmix': { |
| 62 | + ChannelsLast: simple_vision_settings, |
| 63 | + ColOut: simple_vision_settings, |
| 64 | + CutMix: { |
53 | 65 | 'model': common.SimpleConvModel,
|
54 | 66 | 'dataset': common.RandomImageDataset,
|
55 | 67 | 'kwargs': {
|
56 | 68 | 'num_classes': 2
|
57 | 69 | }
|
58 | 70 | },
|
59 |
| - 'cutout': simple_vision_settings, |
60 |
| - 'ema': simple_vision_settings, |
61 |
| - 'factorize': simple_resnet_settings, |
62 |
| - 'ghost_batchnorm': { |
| 71 | + CutOut: simple_vision_settings, |
| 72 | + EMA: { |
| 73 | + 'model': common.SimpleConvModel, |
| 74 | + 'dataset': common.RandomImageDataset, |
| 75 | + 'kwargs': { |
| 76 | + 'half_life': "1ba", |
| 77 | + }, |
| 78 | + }, |
| 79 | + Factorize: simple_resnet_settings, |
| 80 | + GhostBatchNorm: { |
63 | 81 | 'model': (ComposerResNet, {
|
64 | 82 | 'model_name': 'resnet18',
|
65 | 83 | 'num_classes': 2
|
|
71 | 89 | 'ghost_batch_size': 2,
|
72 | 90 | }
|
73 | 91 | },
|
74 |
| - 'label_smoothing': simple_vision_settings, |
75 |
| - 'layer_freezing': simple_vision_settings, |
76 |
| - 'mixup': simple_vision_settings, |
77 |
| - 'progressive_resizing': simple_vision_settings, |
78 |
| - 'randaugment': None, # requires PIL dataset to test |
79 |
| - 'sam': simple_vision_settings, |
80 |
| - 'selective_backprop': simple_vision_settings, |
81 |
| - 'seq_length_warmup': None, # NLP settings needed |
82 |
| - 'squeeze_excite': simple_resnet_settings, |
83 |
| - 'stochastic_depth': { |
| 92 | + LabelSmoothing: simple_vision_settings, |
| 93 | + LayerFreezing: simple_vision_settings, |
| 94 | + MixUp: simple_vision_settings, |
| 95 | + ProgressiveResizing: simple_vision_settings, |
| 96 | + RandAugment: simple_vision_settings, |
| 97 | + NoOpModel: simple_vision_settings, |
| 98 | + SAM: simple_vision_settings, |
| 99 | + SelectiveBackprop: simple_vision_settings, |
| 100 | + SeqLengthWarmup: None, # NLP settings needed |
| 101 | + SqueezeExcite: simple_resnet_settings, |
| 102 | + StochasticDepth: { |
84 | 103 | 'model': (ComposerResNet, {
|
85 | 104 | 'model_name': 'resnet50',
|
86 | 105 | 'num_classes': 2
|
|
93 | 112 | 'target_layer_name': 'ResNetBottleneck',
|
94 | 113 | 'drop_rate': 0.2,
|
95 | 114 | 'drop_distribution': 'linear',
|
96 |
| - 'use_same_gpu_seed': False |
| 115 | + 'drop_warmup': "0.0dur", |
| 116 | + 'use_same_gpu_seed': False, |
97 | 117 | }
|
98 | 118 | },
|
99 |
| - 'swa': { |
| 119 | + SWA: { |
100 | 120 | 'model': common.SimpleConvModel,
|
101 | 121 | 'dataset': common.RandomImageDataset,
|
102 | 122 | 'kwargs': {
|
|
105 | 125 | 'update_interval': '1ep',
|
106 | 126 | 'schedule_swa_lr': True,
|
107 | 127 | }
|
108 |
| - } |
| 128 | + }, |
109 | 129 | }
|
110 | 130 |
|
111 | 131 |
|
112 |
| -def get_settings(name: str): |
113 |
| - """For a given algorithm name, creates the canonical setting |
114 |
| - (algorithm, model, dataset) for testing. |
| 132 | +def _get_alg_settings(alg_cls: Type[Algorithm]): |
| 133 | + if alg_cls not in _settings or _settings[alg_cls] is None: |
| 134 | + raise ValueError(f"Algorithm {alg_cls.__name__} not in the settings dictionary.") |
| 135 | + settings = _settings[alg_cls] |
| 136 | + assert settings is not None |
| 137 | + return settings |
| 138 | + |
| 139 | + |
| 140 | +def get_alg_kwargs(alg_cls: Type[Algorithm]) -> Dict[str, Any]: |
| 141 | + """Return the kwargs for an algorithm.""" |
| 142 | + return _get_alg_settings(alg_cls)['kwargs'] |
115 | 143 |
|
116 |
| - Returns ``None`` if no settings provided. |
| 144 | + |
| 145 | +def get_alg_model(alg_cls: Type[Algorithm]) -> ComposerModel: |
| 146 | + """Return an instance of the model for an algorithm.""" |
| 147 | + settings = _get_alg_settings(alg_cls)['model'] |
| 148 | + if isinstance(settings, tuple): |
| 149 | + (cls, kwargs) = settings |
| 150 | + else: |
| 151 | + (cls, kwargs) = (settings, {}) |
| 152 | + return cls(**kwargs) |
| 153 | + |
| 154 | + |
| 155 | +def get_alg_dataset(alg_cls: Type[Algorithm]) -> Dataset: |
| 156 | + """Return an instance of the dataset for an algorithm.""" |
| 157 | + settings = _get_alg_settings(alg_cls)['dataset'] |
| 158 | + if isinstance(settings, tuple): |
| 159 | + (cls, kwargs) = settings |
| 160 | + else: |
| 161 | + (cls, kwargs) = (settings, {}) |
| 162 | + return cls(**kwargs) |
| 163 | + |
| 164 | + |
| 165 | +def get_algs_with_marks(): |
| 166 | + """Returns a list of algorithms appropriate markers for a subsequent call to pytest.mark.parameterize. |
| 167 | + It applies markers as appropriate (e.g. XFAIL for algs missing config) |
| 168 | + It reads from the algorithm registry |
| 169 | +
|
| 170 | + E.g. @pytest.mark.parametrize("alg_class", get_algs_with_marks()) |
117 | 171 | """
|
118 |
| - if name not in _settings: |
119 |
| - raise ValueError(f'No settings for {name} found, please add.') |
120 |
| - |
121 |
| - setting = _settings[name] |
122 |
| - if setting is None: |
123 |
| - return None |
124 |
| - |
125 |
| - result = {} |
126 |
| - for key in ('model', 'dataset'): |
127 |
| - if isinstance(setting[key], tuple): |
128 |
| - (obj, kwargs) = setting[key] |
129 |
| - else: |
130 |
| - (obj, kwargs) = (setting[key], {}) |
131 |
| - |
132 |
| - # create the object |
133 |
| - result[key] = obj(**kwargs) |
134 |
| - |
135 |
| - # create algorithm |
136 |
| - kwargs = setting.get('kwargs', {}) |
137 |
| - hparams = algorithm_registry.get_algorithm_registry()[name] |
138 |
| - result['algorithm'] = hparams(**kwargs).initialize_object() |
139 |
| - result['algorithm_kwargs'] = kwargs |
140 |
| - |
141 |
| - return result |
| 172 | + ans = [] |
| 173 | + for alg_cls in common.get_module_subclasses(composer.algorithms, Algorithm): |
| 174 | + marks = [] |
| 175 | + settings = _settings[alg_cls] |
| 176 | + |
| 177 | + if alg_cls in (CutMix, MixUp, LabelSmoothing): |
| 178 | + # see: https://github.com/mosaicml/composer/issues/362 |
| 179 | + pytest.importorskip("torch", minversion="1.10", reason="Pytorch 1.10 required.") |
| 180 | + |
| 181 | + if alg_cls == SWA: |
| 182 | + # TODO(matthew): Fix |
| 183 | + marks.append( |
| 184 | + pytest.mark.filterwarnings( |
| 185 | + r'ignore:Detected call of `lr_scheduler.step\(\)` before `optimizer.step\(\)`:UserWarning')) |
| 186 | + |
| 187 | + if alg_cls == MixUp: |
| 188 | + # TODO(Landen): Fix |
| 189 | + marks.append( |
| 190 | + pytest.mark.filterwarnings(r"ignore:Some targets have less than 1 total probability:UserWarning")) |
| 191 | + |
| 192 | + if settings is None: |
| 193 | + marks.append(pytest.mark.xfail(reason=f"Algorithm {alg_cls.__name__} is missing settings.")) |
| 194 | + |
| 195 | + ans.append(pytest.param(alg_cls, marks=marks, id=alg_cls.__name__)) |
| 196 | + |
| 197 | + return ans |
0 commit comments