@@ -21,7 +21,7 @@ def apply_squeeze_excite(
     latent_channels: float = 64,
     min_channels: int = 128,
     optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
-):
+) -> None:
     """Adds Squeeze-and-Excitation blocks (`Hu et al, 2019 <https://arxiv.org/abs/1709.01507>`_) after
     :class:`torch.nn.Conv2d` layers.
@@ -50,9 +50,6 @@ def apply_squeeze_excite(
             then it is safe to omit this parameter. These optimizers will see the correct
             model parameters.
 
-    Returns:
-        The modified model
-
     Example:
 
         .. testcode::
@@ -73,8 +70,6 @@ def convert_module(module: torch.nn.Module, module_index: int):
 
     module_surgery.replace_module_classes(model, optimizers=optimizers, policies={torch.nn.Conv2d: convert_module})
 
-    return model
-
 
 class SqueezeExcite2d(torch.nn.Module):
     """Squeeze-and-Excitation block from (`Hu et al, 2019 <https://arxiv.org/abs/1709.01507>`_)
@@ -164,10 +159,10 @@ def match(self, event: Event, state: State) -> bool:
         return event == Event.INIT
 
     def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
-        state.model = apply_squeeze_excite(state.model,
-                                           optimizers=state.optimizers,
-                                           latent_channels=self.latent_channels,
-                                           min_channels=self.min_channels)
+        apply_squeeze_excite(state.model,
+                             optimizers=state.optimizers,
+                             latent_channels=self.latent_channels,
+                             min_channels=self.min_channels)
 
         layer_count = module_surgery.count_module_instances(state.model, SqueezeExciteConv2d)
 
         log.info(f'Applied SqueezeExcite to model {state.model.__class__.__name__} '
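
Taken together, the change means callers should no longer reassign the return value. A short usage sketch, assuming `apply_squeeze_excite` and `SqueezeExciteConv2d` are importable from `composer.algorithms.squeeze_excite` and `module_surgery` from `composer.utils` (paths inferred, not confirmed by this diff), with a torchvision model as an illustrative stand-in:

import torchvision.models as models

from composer.algorithms.squeeze_excite import SqueezeExciteConv2d, apply_squeeze_excite
from composer.utils import module_surgery

model = models.resnet18()

# The function now mutates `model` in place and returns None, so the old
# `model = apply_squeeze_excite(model)` reassignment pattern is obsolete.
apply_squeeze_excite(model, latent_channels=64, min_channels=128)

# Confirm SE blocks were inserted after the Conv2d layers.
assert module_surgery.count_module_instances(model, SqueezeExciteConv2d) > 0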