
Commit 9387d5c

Remove plural types and aliases for native pytorch types (#677)
* First pass fix ensure_tuple
* Remove endline
* Adjust state docstring
* Adjust docstring
* Address comments
* Ooops typo
* Add tests
* Lint
* Fix types
* Fixed type errors
* Docstring styling
* Errr docstring links?
* Attempt at removing plural types pt. 1
* Remove plural types pt. 2
* Clean up types.py
  Remove re-exported types from types.py. This should help with code readability, as now the underlying type is directly imported.
* Rearranged import
* Fix circular import
* Fix doctests
* Replaced StateDict type annotation with `Dict[str, Any]`
  `Dict[str, Any]` is almost as short but clearer on what it is (a dictionary with string keys).
* Removed some more plural types
* Cleanup tensor_to_device
* Fixed typo
* Allowing generic sequences
* Fix LoggerDestination issues

Co-authored-by: ravi-mosaicml <[email protected]>
1 parent 5bff422 commit 9387d5c
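To make the pattern behind the diff concrete, here is a hypothetical before/after sketch (the `save` helper below is illustrative only and is not code from this commit; it only assumes `torch` is installed):

```python
from typing import Any, Dict, Optional, Sequence, Union

import torch
from torch.optim import Optimizer


# Before this commit, public signatures used aliases re-exported from composer.core.types, e.g.
#     def save(model, optimizers: Optional[Optimizers] = None) -> StateDict: ...
# After this commit, the underlying torch/typing types are named directly:
def save(model: torch.nn.Module,
         optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> Dict[str, Any]:
    """Hypothetical helper: collect a checkpoint-style dictionary."""
    state: Dict[str, Any] = {"model": model.state_dict()}
    if optimizers is not None:
        opts = [optimizers] if isinstance(optimizers, Optimizer) else list(optimizers)
        state["optimizers"] = [opt.state_dict() for opt in opts]
    return state
```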

98 files changed (+512, -502 lines changed)

STYLE_GUIDE.md

Lines changed: 2 additions & 2 deletions

@@ -119,13 +119,13 @@ The following rules apply to public APIs:
 
 1. Parameters that could take a sequence of elements should also allow `None` or a singleton.
    This simplifies the user API by not having to construct a list (or tuple) to hold a single element
-   (or no element). For example, `Tensors = Union[Tensor, Tuple[Tensor, ...], List[Tensor]]`.
+   (or no element). For example, use `Optional[Union[torch.Tensor, Sequence[torch.Tensor]]`.
 
    The `composer.utils.ensure_tuple` helper method can convert a singleton, list, or tuple into a tuple.
    For example
 
    ```python
-   def foo(x: Optional[Tensors]) -> Tuple[Tensor, ...]:
+   def foo(x: Optional[Union[Tensor, Sequence[Tensor]]) -> Tuple[Tensor, ...]:
        return ensure_tuple(x) # ensures that the result is always a (potentially empty) tuple of tensors
    ```
 
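As a usage sketch of the convention above: `foo` mirrors the guide's example (written here with its closing bracket included), and the three calls show the accepted input shapes. It assumes `composer` is installed so that `composer.utils.ensure_tuple` is importable.

```python
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import Tensor

from composer.utils import ensure_tuple


def foo(x: Optional[Union[Tensor, Sequence[Tensor]]]) -> Tuple[Tensor, ...]:
    # ensure_tuple normalizes None, a single tensor, or a sequence of tensors
    # into a (possibly empty) tuple, so downstream code handles only one shape.
    return ensure_tuple(x)


t = torch.zeros(2)
print(foo(None))    # ()                    -- no element
print(foo(t))       # (tensor([0., 0.]),)   -- a singleton is wrapped
print(foo([t, t]))  # a 2-tuple of tensors
```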

composer/algorithms/alibi/alibi.py

Lines changed: 7 additions & 7 deletions

@@ -9,12 +9,12 @@
 import math
 from operator import attrgetter
 from types import MethodType, ModuleType
-from typing import Any, Callable, Optional, Tuple, Type, Union, cast
+from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, cast
 
 import torch
+from torch.optim import Optimizer
 
 from composer.core import Algorithm, Event, State
-from composer.core.types import Optimizers
 from composer.loggers import Logger
 from composer.utils import module_surgery
 

@@ -32,7 +32,7 @@ def apply_alibi(
     attr_to_replace: str,
     alibi_attention: Callable,
     mask_replacement_function: Optional[Callable[[torch.nn.Module, int], torch.nn.Module]] = None,
-    optimizers: Optional[Optimizers] = None,
+    optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
 ) -> None:
     """Removes position embeddings and replaces the attention function and attention mask
     according as per :class:`~composer.algorithms.alibi.alibi.Alibi`. Note that the

@@ -83,10 +83,10 @@ def apply_alibi(
             ``max_sequence_length``. For example,
             ``composer.algorithms.alibi._gpt2_alibi.enlarge_mask``. Default: ``None``,
             which means no modification of the model's default attention mask.
-        optimizers (Optimizers, optional): Existing optimizers bound to ``model.parameters()``.
-            All optimizers that have already been constructed with
-            ``model.parameters()`` must be specified here so they will optimize
-            the correct parameters. Default: ``None``.
+        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
+            Existing optimizers bound to ``model.parameters()``. All optimizers that have already been
+            constructed with ``model.parameters()`` must be specified here so
+            they will optimize the correct parameters.
 
             If the optimizer(s) are constructed *after* calling this function,
             then it is safe to omit this parameter. These optimizers will see the correct
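The rewritten docstring keeps the same caveat that applies to every surgery-style algorithm in this commit: an optimizer built before the surgery still references the old parameters. A minimal plain-PyTorch illustration of that caveat (independent of composer's `module_surgery` internals):

```python
import torch
from torch import nn
from torch.optim import SGD

model = nn.Sequential(nn.Linear(4, 4))
opt = SGD(model.parameters(), lr=0.1)   # the optimizer captures references to the current parameters

model[0] = nn.Linear(4, 4)              # "surgery": the submodule (and its parameters) are replaced

new_params = set(model.parameters())
opt_params = {p for group in opt.param_groups for p in group["params"]}
print(new_params == opt_params)         # False: the optimizer still points at the replaced weights
```

Passing pre-constructed optimizers to `apply_alibi` lets the surgery update them; optimizers constructed after the call never see the stale parameters, which is why the argument may then be omitted.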

composer/algorithms/augmix/augmix.py

Lines changed: 1 addition & 2 deletions

@@ -15,8 +15,7 @@
 
 from composer.algorithms.utils import augmentation_sets
 from composer.algorithms.utils.augmentation_common import map_pillow_function
-from composer.core.event import Event
-from composer.core.types import Algorithm, Event, State
+from composer.core import Algorithm, Event, State
 from composer.datasets.utils import add_vision_dataset_transform
 from composer.loggers import Logger
 

composer/algorithms/blurpool/blurpool.py

Lines changed: 5 additions & 5 deletions

@@ -4,14 +4,14 @@
 
 import functools
 import logging
-from typing import Optional
+from typing import Optional, Sequence, Union
 
 import numpy as np
 import torch
+from torch.optim import Optimizer
 
 from composer.algorithms.blurpool.blurpool_layers import BlurConv2d, BlurMaxPool2d
 from composer.core import Algorithm, Event, State
-from composer.core.types import Optimizers
 from composer.loggers import Logger
 from composer.utils import module_surgery
 

@@ -22,7 +22,7 @@ def apply_blurpool(model: torch.nn.Module,
                    replace_convs: bool = True,
                    replace_maxpools: bool = True,
                    blur_first: bool = True,
-                   optimizers: Optional[Optimizers] = None) -> torch.nn.Module:
+                   optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> torch.nn.Module:
     """Add anti-aliasing filters to the strided :class:`torch.nn.Conv2d` and/or :class:`torch.nn.MaxPool2d` modules
     within `model`.
 

@@ -41,8 +41,8 @@ def apply_blurpool(model: torch.nn.Module,
             overhead (though more closely matching
             `the paper <http://proceedings.mlr.press/v97/zhang19a.html>`_).
             See :class:`.BlurConv2d` for further discussion. Default: ``True``.
-        optimizers (Optimizers, optional): Existing optimizers bound to
-            ``model.parameters()``. All optimizers that have already been
+        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
+            Existing optimizers bound to ``model.parameters()``. All optimizers that have already been
             constructed with ``model.parameters()`` must be specified here so
             they will optimize the correct parameters.
 
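A hedged usage sketch of the signature above. The import path is assumed from the file layout shown here, and a torchvision ResNet is used only as a convenient model containing strided convolutions and max pooling:

```python
import torch
from torchvision.models import resnet18

from composer.algorithms.blurpool import apply_blurpool  # import path assumed from the file layout above

model = resnet18()

# Construct the optimizer *after* the surgery, so the `optimizers` argument can be omitted.
apply_blurpool(model, replace_convs=True, replace_maxpools=True, blur_first=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```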

composer/algorithms/channels_last/channels_last.py

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
 
 import torch
 
-from composer.core.types import Algorithm, Event, State
+from composer.core import Algorithm, Event, State
 from composer.loggers import Logger
 
 log = logging.getLogger(__name__)

composer/algorithms/colout/colout.py

Lines changed: 1 addition & 1 deletion

@@ -11,11 +11,11 @@
 
 import torch
 from PIL.Image import Image as PillowImage
+from torch import Tensor
 from torchvision.datasets import VisionDataset
 
 from composer.algorithms.utils.augmentation_common import image_as_type
 from composer.core import Algorithm, Event, State
-from composer.core.types import Tensor
 from composer.datasets.utils import add_vision_dataset_transform
 from composer.loggers import Logger
 

composer/algorithms/cutmix/cutmix.py

Lines changed: 2 additions & 1 deletion

@@ -9,9 +9,10 @@
 
 import numpy as np
 import torch
+from torch import Tensor
 from torch.nn import functional as F
 
-from composer.core.types import Algorithm, Event, State, Tensor
+from composer.core import Algorithm, Event, State
 from composer.loggers import Logger
 from composer.models.loss import _check_for_index_targets

composer/algorithms/cutout/cutout.py

Lines changed: 2 additions & 1 deletion

@@ -10,9 +10,10 @@
 import numpy as np
 import torch
 from PIL.Image import Image as PillowImage
+from torch import Tensor
 
 from composer.algorithms.utils.augmentation_common import image_as_type
-from composer.core.types import Algorithm, Event, State, Tensor
+from composer.core import Algorithm, Event, State
 from composer.loggers import Logger
 
 log = logging.getLogger(__name__)

composer/algorithms/factorize/factorize.py

Lines changed: 7 additions & 7 deletions

@@ -3,14 +3,14 @@
 from __future__ import annotations
 
 import logging
-from typing import Optional, Type, Union, cast
+from typing import Optional, Sequence, Type, Union, cast
 
 import torch
+from torch.optim import Optimizer
 
 from composer.algorithms.factorize.factorize_modules import (FactorizedConv2d, FactorizedLinear,
                                                              factorizing_could_speedup)
 from composer.core import Algorithm, Event, State
-from composer.core.types import Optimizers
 from composer.loggers import Logger
 from composer.utils import module_surgery
 

@@ -27,7 +27,7 @@ def apply_factorization(model: torch.nn.Module,
                         latent_channels: Union[int, float] = 0.25,
                         min_features: int = 512,
                         latent_features: Union[int, float] = 0.25,
-                        optimizers: Optional[Optimizers] = None) -> torch.nn.Module:
+                        optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> torch.nn.Module:
     """Replaces :class:`~torch.nn.Linear` and :class:`~torch.nn.Conv2d` modules and with
     :class:`~composer.algorithms.factorize.FactorizedLinear` and
     :class:`~composer.algorithms.factorize.FactorizedConv2d` modules.

@@ -62,8 +62,8 @@ def apply_factorization(model: torch.nn.Module,
             ``min(in_features, out_features)`` for each :class:`~torch.nn.Linear`
             module, and is converted to the equivalent integer value, with a
             minimum of 1. Default: ``0.25``.
-        optimizers (Optimizers, optional): Existing optimizers bound to
-            ``model.parameters()``. All optimizers that have already been
+        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
+            Existing optimizers bound to ``model.parameters()``. All optimizers that have already been
             constructed with ``model.parameters()`` must be specified here so
             they will optimize the correct parameters.
 

@@ -217,7 +217,7 @@ def _python_log_surgery_result(model: torch.nn.Module, new_class: Type[torch.nn.
 def _factorize_conv2d_modules(model: torch.nn.Module,
                               min_channels: int = 512,
                               latent_channels: Union[int, float] = 0.25,
-                              optimizers: Optional[Optimizers] = None):
+                              optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None):
     """Replaces :class:`~torch.nn.Conv2d` modules in ``model`` with
     :class:`~composer.algorithms.factorize.FactorizedConv2d` modules.
 

@@ -241,7 +241,7 @@ def _maybe_replace_conv2d(module: torch.nn.Module, module_index: int) -> Optiona
 def _factorize_linear_modules(model: torch.nn.Module,
                               min_features: int = 512,
                               latent_features: Union[int, float] = 0.25,
-                              optimizers: Optional[Optimizers] = None):
+                              optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None):
     """Replaces :class:`~torch.nn.Linear` modules in ``model`` with
     :class:`~composer.algorithms.factorize.FactorizedLinear` modules.
 
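For intuition about what the factorized replacement does, here is a standalone sketch of the low-rank idea behind a factorized linear layer. This is not composer's `FactorizedLinear` implementation; it only mirrors the `latent_features=0.25` default fraction from the docstring above:

```python
import torch
from torch import nn

in_features, out_features = 512, 512
latent_features = max(1, int(0.25 * min(in_features, out_features)))  # fraction -> integer latent width

dense = nn.Linear(in_features, out_features)
factorized = nn.Sequential(                       # W (512x512) is approximated by two smaller maps
    nn.Linear(in_features, latent_features, bias=False),
    nn.Linear(latent_features, out_features),
)


def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


print(n_params(dense), n_params(factorized))      # ~263k vs ~132k parameters
```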

composer/algorithms/ghost_batchnorm/ghost_batchnorm.py

Lines changed: 7 additions & 7 deletions

@@ -3,13 +3,13 @@
 from __future__ import annotations
 
 import logging
-from typing import Optional
+from typing import Optional, Sequence, Union
 
 import numpy as np
 import torch
+from torch.optim import Optimizer
 
 from composer.core import Algorithm, Event, State
-from composer.core.types import Optimizers
 from composer.loggers import Logger
 from composer.utils import module_surgery
 

@@ -20,7 +20,7 @@
 
 def apply_ghost_batchnorm(model: torch.nn.Module,
                           ghost_batch_size: int = 32,
-                          optimizers: Optional[Optimizers] = None) -> torch.nn.Module:
+                          optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> torch.nn.Module:
     """Replace batch normalization modules with ghost batch normalization modules.
 
     Ghost batch normalization modules split their input into chunks of

@@ -30,10 +30,10 @@ def apply_ghost_batchnorm(model: torch.nn.Module,
     Args:
         model (torch.nn.Module): the model to modify in-place
        ghost_batch_size (int, optional): size of sub-batches to normalize over. Default: ``32``.
-        optimizers (Optimizers, optional): Existing optimizers bound to ``model.parameters()``.
-            All optimizers that have already been constructed with
-            ``model.parameters()`` must be specified here so they will optimize
-            the correct parameters.
+        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
+            Existing optimizers bound to ``model.parameters()``. All optimizers that have already been
+            constructed with ``model.parameters()`` must be specified here so
+            they will optimize the correct parameters.
 
             If the optimizer(s) are constructed *after* calling this function,
             then it is safe to omit this parameter. These optimizers will see the correct
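The docstring spells out the mechanism: normalize each chunk of `ghost_batch_size` samples separately rather than the whole batch at once. A minimal standalone sketch of that idea (not composer's implementation):

```python
import torch
from torch import nn


def ghost_batchnorm(x: torch.Tensor, bn: nn.BatchNorm2d, ghost_batch_size: int = 32) -> torch.Tensor:
    # Split the batch into chunks of `ghost_batch_size` samples and normalize each chunk
    # independently, instead of using statistics computed over the full batch.
    chunks = x.split(ghost_batch_size, dim=0)
    return torch.cat([bn(chunk) for chunk in chunks], dim=0)


bn = nn.BatchNorm2d(8)
out = ghost_batchnorm(torch.randn(128, 8, 4, 4), bn, ghost_batch_size=32)  # 4 ghost batches of 32
print(out.shape)  # torch.Size([128, 8, 4, 4])
```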
