diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 5c474fefcb..95595cd1b4 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -69,13 +69,34 @@ from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d +from .norm import ( + GroupNorm, + GroupNorm1, + LayerNorm, + LayerNorm2d, + LayerNormFp32, + LayerNorm2dFp32, + RmsNorm, + RmsNorm2d, + RmsNormFp32, + RmsNorm2dFp32, + SimpleNorm, + SimpleNorm2d, + SimpleNormFp32, + SimpleNorm2dFp32, +) from .norm_act import ( BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d, + LayerNormActFp32, + LayerNormAct2dFp32, + RmsNormAct, + RmsNormAct2d, + RmsNormActFp32, + RmsNormAct2dFp32, SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, diff --git a/timm/layers/create_act.py b/timm/layers/create_act.py index c734785d13..b066871078 100644 --- a/timm/layers/create_act.py +++ b/timm/layers/create_act.py @@ -1,11 +1,12 @@ """ Activation Factory Hacked together by / Copyright 2020 Ross Wightman """ -from typing import Union, Callable, Type +from typing import Callable, Optional, Type, Union from .activations import * from .activations_me import * from .config import is_exportable, is_scriptable +from .typing import LayerType # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. @@ -88,7 +89,7 @@ a.setdefault('hardswish', a.get('hard_swish')) -def get_act_fn(name: Union[Callable, str] = 'relu'): +def get_act_fn(name: Optional[LayerType] = 'relu'): """ Activation Function Factory Fetching activation fns by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. @@ -106,7 +107,7 @@ def get_act_fn(name: Union[Callable, str] = 'relu'): return _ACT_FN_DEFAULT[name] -def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): +def get_act_layer(name: Optional[LayerType] = 'relu'): """ Activation Layer Factory Fetching activation layers by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. 
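A minimal usage sketch of the relaxed `Optional[LayerType]` typing above (illustrative only, not part of the diff; assumes the existing `timm.layers` factory exports): the activation factories accept a string name, an `nn.Module` type/callable, or `None`, with `None` resolving to no layer.

import torch
from torch import nn
from timm.layers import create_act_layer, get_act_layer

act_cls = get_act_layer('silu')                  # name -> class (native or timm variant per export/script config)
gelu = create_act_layer('gelu')                  # name -> instantiated layer
relu = create_act_layer(nn.ReLU, inplace=True)   # types / callables pass straight through
noop = create_act_layer(None)                    # None -> None, caller decides how to handle "no activation"

y = gelu(torch.randn(2, 8))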
@@ -125,7 +126,11 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): return _ACT_LAYER_DEFAULT[name] -def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs): +def create_act_layer( + name: Optional[LayerType], + inplace: Optional[bool] = None, + **kwargs +): act_layer = get_act_layer(name) if act_layer is None: return None diff --git a/timm/layers/create_norm.py b/timm/layers/create_norm.py index 75262b5eca..0821be7fce 100644 --- a/timm/layers/create_norm.py +++ b/timm/layers/create_norm.py @@ -10,7 +10,22 @@ import torch.nn as nn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d +from .norm import ( + GroupNorm, + GroupNorm1, + LayerNorm, + LayerNorm2d, + LayerNormFp32, + LayerNorm2dFp32, + RmsNorm, + RmsNorm2d, + RmsNormFp32, + RmsNorm2dFp32, + SimpleNorm, + SimpleNorm2d, + SimpleNormFp32, + SimpleNorm2dFp32, +) from torchvision.ops.misc import FrozenBatchNorm2d _NORM_MAP = dict( @@ -21,10 +36,16 @@ groupnorm1=GroupNorm1, layernorm=LayerNorm, layernorm2d=LayerNorm2d, + layernormfp32=LayerNormFp32, + layernorm2dfp32=LayerNorm2dFp32, rmsnorm=RmsNorm, rmsnorm2d=RmsNorm2d, + rmsnormfp32=RmsNormFp32, + rmsnorm2dfp32=RmsNorm2dFp32, simplenorm=SimpleNorm, simplenorm2d=SimpleNorm2d, + simplenormfp32=SimpleNormFp32, + simplenorm2dfp32=SimpleNorm2dFp32, frozenbatchnorm2d=FrozenBatchNorm2d, ) _NORM_TYPES = {m for n, m in _NORM_MAP.items()} diff --git a/timm/layers/create_norm_act.py b/timm/layers/create_norm_act.py index e7805e7993..4b8d2bf2ff 100644 --- a/timm/layers/create_norm_act.py +++ b/timm/layers/create_norm_act.py @@ -8,19 +8,35 @@ """ import types import functools +from typing import Optional from .evo_norm import * from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d -from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, RmsNormAct, RmsNormAct2d +from .norm_act import ( + BatchNormAct2d, + GroupNormAct, + GroupNorm1Act, + LayerNormAct, + LayerNormActFp32, + LayerNormAct2d, + LayerNormAct2dFp32, + RmsNormAct, + RmsNormActFp32, + RmsNormAct2d, + RmsNormAct2dFp32, +) from .inplace_abn import InplaceAbn +from .typing import LayerType _NORM_ACT_MAP = dict( batchnorm=BatchNormAct2d, batchnorm2d=BatchNormAct2d, groupnorm=GroupNormAct, - groupnorm1=functools.partial(GroupNormAct, num_groups=1), + groupnorm1=GroupNorm1Act, layernorm=LayerNormAct, layernorm2d=LayerNormAct2d, + layernormfp32=LayerNormActFp32, + layernorm2dfp32=LayerNormAct2dFp32, evonormb0=EvoNorm2dB0, evonormb1=EvoNorm2dB1, evonormb2=EvoNorm2dB2, @@ -36,22 +52,51 @@ iabn=InplaceAbn, rmsnorm=RmsNormAct, rmsnorm2d=RmsNormAct2d, + rmsnormfp32=RmsNormActFp32, + rmsnorm2dfp32=RmsNormAct2dFp32, ) _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} +# Reverse map from base norm layer names to norm+act layer classes +_NORM_TO_NORM_ACT_MAP = dict( + batchnorm=BatchNormAct2d, + batchnorm2d=BatchNormAct2d, + groupnorm=GroupNormAct, + groupnorm1=GroupNorm1Act, + layernorm=LayerNormAct, + layernorm2d=LayerNormAct2d, + layernormfp32=LayerNormActFp32, + layernorm2dfp32=LayerNormAct2dFp32, + rmsnorm=RmsNormAct, + rmsnorm2d=RmsNormAct2d, + rmsnormfp32=RmsNormActFp32, + rmsnorm2dfp32=RmsNormAct2dFp32, +) # has act_layer arg to define act type _NORM_ACT_REQUIRES_ARG = { BatchNormAct2d, GroupNormAct, + GroupNorm1Act, LayerNormAct, LayerNormAct2d, + LayerNormActFp32, + LayerNormAct2dFp32, FilterResponseNormAct2d, InplaceAbn, RmsNormAct, RmsNormAct2d, + RmsNormActFp32, + RmsNormAct2dFp32, } 
-def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): +def create_norm_act_layer( + layer_name: LayerType, + num_features: int, + act_layer: Optional[LayerType] = None, + apply_act: bool = True, + jit: bool = False, + **kwargs, +): layer = get_norm_act_layer(layer_name, act_layer=act_layer) layer_instance = layer(num_features, apply_act=apply_act, **kwargs) if jit: @@ -59,7 +104,10 @@ def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=Tr return layer_instance -def get_norm_act_layer(norm_layer, act_layer=None): +def get_norm_act_layer( + norm_layer: LayerType, + act_layer: Optional[LayerType] = None, +): if norm_layer is None: return None assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) @@ -82,21 +130,10 @@ def get_norm_act_layer(norm_layer, act_layer=None): # if function type, must be a lambda/fn that creates a norm_act layer norm_act_layer = norm_layer else: + # Use reverse map to find the corresponding norm+act layer type_name = norm_layer.__name__.lower() - if type_name.startswith('batchnorm'): - norm_act_layer = BatchNormAct2d - elif type_name.startswith('groupnorm'): - norm_act_layer = GroupNormAct - elif type_name.startswith('groupnorm1'): - norm_act_layer = functools.partial(GroupNormAct, num_groups=1) - elif type_name.startswith('layernorm2d'): - norm_act_layer = LayerNormAct2d - elif type_name.startswith('layernorm'): - norm_act_layer = LayerNormAct - elif type_name.startswith('rmsnorm2d'): - norm_act_layer = RmsNormAct2d - else: - assert False, f"No equivalent norm_act layer for {type_name}" + norm_act_layer = _NORM_TO_NORM_ACT_MAP.get(type_name, None) + assert norm_act_layer is not None, f"No equivalent norm_act layer for {type_name}" if norm_act_layer in _NORM_ACT_REQUIRES_ARG: # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 
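To illustrate the reverse-map lookup above (a sketch, not part of the diff; assumes the usual `timm.layers` factory exports): a registered string name or a bare norm class resolves to the matching norm+act class, and `act_layer` is bound via `functools.partial` for classes in `_NORM_ACT_REQUIRES_ARG`.

import torch
from torch import nn
from timm.layers import LayerNorm2d, create_norm_act_layer, get_norm_act_layer

# string key -> norm+act class, with act_layer pre-bound via functools.partial
rms_na = get_norm_act_layer('rmsnorm2dfp32', act_layer='silu')

# bare norm class -> matching norm+act class via the __name__-based reverse map
ln_na = get_norm_act_layer(LayerNorm2d, act_layer=nn.GELU)

# one-step construction: resolves the class, then instantiates with apply_act handling
layer = create_norm_act_layer('layernorm2dfp32', 64, act_layer='gelu')
y = layer(torch.randn(2, 64, 7, 7))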
@@ -104,4 +141,5 @@ def get_norm_act_layer(norm_layer, act_layer=None): norm_act_kwargs.setdefault('act_layer', act_layer) if norm_act_kwargs: norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args + return norm_act_layer diff --git a/timm/layers/norm.py b/timm/layers/norm.py index 3a6e87524e..79a19bf323 100644 --- a/timm/layers/norm.py +++ b/timm/layers/norm.py @@ -12,8 +12,14 @@ import torch.nn.functional as F from .fast_norm import ( - is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d, - fast_simple_norm, simple_norm + is_fast_norm, + fast_group_norm, + fast_layer_norm, + fast_rms_norm, + rms_norm2d, + fast_rms_norm2d, + fast_simple_norm, + simple_norm, ) try: @@ -25,9 +31,16 @@ class GroupNorm(nn.GroupNorm): _fast_norm: torch.jit.Final[bool] - def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): + def __init__( + self, + num_channels: int, + num_groups: int = 32, + eps: float = 1e-5, + affine: bool = True, + **kwargs, + ): # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN - super().__init__(num_groups, num_channels, eps=eps, affine=affine) + super().__init__(num_groups, num_channels, eps=eps, affine=affine, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) def forward(self, x): @@ -43,7 +56,7 @@ class GroupNorm1(nn.GroupNorm): """ _fast_norm: torch.jit.Final[bool] - def __init__(self, num_channels, **kwargs): + def __init__(self, num_channels: int, **kwargs): super().__init__(1, num_channels, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) @@ -59,8 +72,14 @@ class LayerNorm(nn.LayerNorm): """ _fast_norm: torch.jit.Final[bool] - def __init__(self, num_channels, eps=1e-6, affine=True): - super().__init__(num_channels, eps=eps, elementwise_affine=affine) + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + **kwargs, + ): + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -71,12 +90,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class LayerNormFp32(nn.LayerNorm): + """ LayerNorm + """ + + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + **kwargs, + ): + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) + return x + + class LayerNorm2d(nn.LayerNorm): """ LayerNorm for channels of '2D' spatial NCHW tensors """ _fast_norm: torch.jit.Final[bool] - def __init__(self, num_channels, eps=1e-6, affine=True): - super().__init__(num_channels, eps=eps, elementwise_affine=affine) + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + **kwargs, + ): + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -89,6 +132,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class LayerNorm2dFp32(nn.LayerNorm): + """ LayerNorm for channels of '2D' spatial NCHW tensors """ + + def __init__( + self, + 
num_channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            **kwargs,
+    ):
+        super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.permute(0, 2, 3, 1)
+        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        x = x.permute(0, 3, 1, 2)
+        return x
+
+
 def _is_contiguous(tensor: torch.Tensor) -> bool:
     # jit is oh so lovely :/
     if torch.jit.is_scripting():
@@ -121,7 +183,7 @@ class LayerNormExp2d(nn.LayerNorm):
     layout. However, benefits are not always clear and can perform worse on other GPUs.
     """
 
-    def __init__(self, num_channels, eps=1e-6):
+    def __init__(self, num_channels: int, eps: float = 1e-6):
         super().__init__(num_channels, eps=eps)
 
     def forward(self, x) -> torch.Tensor:
@@ -142,7 +204,14 @@ class RmsNorm(nn.Module):
     elementwise_affine: bool
     _fast_norm: bool
 
-    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+    def __init__(
+            self,
+            channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            device=None,
+            dtype=None,
+    ) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         normalized_shape = channels
@@ -167,7 +236,7 @@ def reset_parameters(self) -> None:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
-        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
+        # Since there is no built-in PyTorch impl, always uses APEX RmsNorm if installed.
         if self._fast_norm:
             x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
         else:
@@ -175,6 +244,48 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
+
+class RmsNormFp32(nn.Module):
+    """ RmsNorm w/ fp32 compute (input cast to float32, output cast back to input dtype)
+    """
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    normalized_shape: Tuple[int, ...]
+ eps: float + elementwise_affine: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + return x + + class RmsNorm2d(nn.Module): """ RmsNorm2D for NCHW tensors, w/ fast apex or cast norm if available @@ -188,7 +299,14 @@ class RmsNorm2d(nn.Module): elementwise_affine: bool _fast_norm: bool - def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels @@ -221,6 +339,52 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class RmsNorm2dFp32(nn.Module): + """ RmsNorm2D for NCHW tensors, w/ fast apex or cast norm if available + + NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction + on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something + like https://github.com/pytorch/pytorch/pull/150576 lands. + """ + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + return x + + class SimpleNorm(nn.Module): """ SimpleNorm (x / std(x)) """ @@ -230,7 +394,14 @@ class SimpleNorm(nn.Module): elementwise_affine: bool _fast_norm: bool - def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels @@ -261,6 +432,48 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class SimpleNormFp32(nn.Module): + """ SimpleNorm (x / std(x)) + """ + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + normalized_shape: Tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + return x + + class SimpleNorm2d(nn.Module): """ SimpleNorm for NCHW tensors """ @@ -270,7 +483,14 @@ class SimpleNorm2d(nn.Module): elementwise_affine: bool _fast_norm: bool - def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels @@ -301,3 +521,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = simple_norm(x, self.normalized_shape, self.weight, self.eps) x = x.permute(0, 3, 1, 2) return x + + +class SimpleNorm2dFp32(nn.Module): + """ SimpleNorm for NCHW tensors + """ + __constants__ = ['normalized_shape', 'eps', 
'elementwise_affine'] + normalized_shape: Tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 3, 1) + x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + x = x.permute(0, 3, 1, 2) + return x diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index 42eb2da609..7dbb5e0f2c 100644 --- a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -12,7 +12,7 @@ Hacked together by / Copyright 2022 Ross Wightman """ -from typing import Union, List, Optional, Any +from typing import Any, Dict, List, Optional, Type, Union import torch from torch import nn as nn @@ -21,9 +21,17 @@ from ._fx import register_notrace_module from .create_act import create_act_layer -from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d +from .fast_norm import ( + is_fast_norm, + fast_group_norm, + fast_layer_norm, + fast_rms_norm, + rms_norm2d, + fast_rms_norm2d, +) from .norm import RmsNorm, RmsNorm2d from .trace_utils import _assert +from .typing import LayerType try: from torch.nn.functional import rms_norm @@ -31,7 +39,12 @@ from .fast_norm import rms_norm -def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True): +def _create_act( + act_layer: LayerType, + act_kwargs: Dict[str, Any] = None, + inplace: Optional[bool] = False, + apply_act: bool = True, +) -> nn.Module: act_kwargs = act_kwargs or {} act_kwargs.setdefault('inplace', inplace) act = None @@ -50,16 +63,16 @@ class BatchNormAct2d(nn.BatchNorm2d): """ def __init__( self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - track_running_stats=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, device=None, dtype=None, ): @@ -208,11 +221,11 @@ def __init__( self, num_features: int, eps: float = 1e-5, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, ): super().__init__() self.eps = eps @@ -344,7 +357,7 @@ def unfreeze_batch_norm_2d(module): return res -def _num_groups(num_channels, num_groups, group_size): +def _num_groups(num_channels: int, num_groups: int, group_size: int): if group_size: assert 
num_channels % group_size == 0 return num_channels // group_size @@ -352,19 +365,21 @@ def _num_groups(num_channels, num_groups, group_size): class GroupNormAct(nn.GroupNorm): + _fast_norm: torch.jit.Final[bool] + # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args def __init__( self, - num_channels, - num_groups=32, - eps=1e-5, - affine=True, - group_size=None, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_channels: int, + num_groups: int = 32, + eps: float = 1e-5, + affine: bool = True, + group_size: Optional[int] = None, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, ): super(GroupNormAct, self).__init__( _num_groups(num_channels, num_groups, group_size), @@ -388,16 +403,18 @@ def forward(self, x): class GroupNorm1Act(nn.GroupNorm): + _fast_norm: torch.jit.Final[bool] + def __init__( self, - num_channels, - eps=1e-5, - affine=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, ): super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() @@ -416,18 +433,21 @@ def forward(self, x): class LayerNormAct(nn.LayerNorm): + _fast_norm: torch.jit.Final[bool] + def __init__( self, normalization_shape: Union[int, List[int], torch.Size], - eps=1e-5, - affine=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + eps: float = 1e-5, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, ): - super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine) + super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine, **kwargs) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) @@ -443,19 +463,47 @@ def forward(self, x): return x +class LayerNormActFp32(nn.LayerNorm): + + def __init__( + self, + normalization_shape: Union[int, List[int], torch.Size], + eps: float = 1e-5, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, + ): + super().__init__(normalization_shape, eps=eps, elementwise_affine=affine, **kwargs) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) + + def forward(self, x): + x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) + x = self.drop(x) + x = self.act(x) + return x + + class LayerNormAct2d(nn.LayerNorm): + _fast_norm: torch.jit.Final[bool] + def __init__( self, - num_channels, - eps=1e-5, - affine=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_channels: int, + eps: float = 1e-5, 
+ affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, ): - super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine) + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() @@ -472,6 +520,33 @@ def forward(self, x): return x +class LayerNormAct2dFp32(nn.LayerNorm): + + def __init__( + self, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, + ): + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) + x = x.permute(0, 3, 1, 2) + x = self.drop(x) + x = self.act(x) + return x + + class RmsNormAct(RmsNorm): """ RMSNorm + Activation for '2D' NCHW tensors @@ -481,16 +556,17 @@ class RmsNormAct(RmsNorm): """ def __init__( self, - num_channels, - eps=1e-6, - affine=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, ): - super().__init__(channels=num_channels, eps=eps, affine=affine) + super().__init__(channels=num_channels, eps=eps, affine=affine, **kwargs) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() @@ -505,6 +581,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class RmsNormActFp32(RmsNorm): + """ RMSNorm + Activation for '2D' NCHW tensors + + NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction + on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something + like https://github.com/pytorch/pytorch/pull/150576 lands. 
+ """ + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, + ): + super().__init__(channels=num_channels, eps=eps, affine=affine, **kwargs) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + x = self.drop(x) + x = self.act(x) + return x + + class RmsNormAct2d(RmsNorm2d): """ RMSNorm + Activation for '2D' NCHW tensors @@ -514,14 +620,14 @@ class RmsNormAct2d(RmsNorm2d): """ def __init__( self, - num_channels, - eps=1e-6, - affine=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, ): super().__init__(channels=num_channels, eps=eps, affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() @@ -536,3 +642,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.drop(x) x = self.act(x) return x + + +class RmsNormAct2dFp32(RmsNorm2d): + """ RMSNorm + Activation for '2D' NCHW tensors + + NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction + on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something + like https://github.com/pytorch/pytorch/pull/150576 lands. + """ + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + ): + super().__init__(channels=num_channels, eps=eps, affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + x = self.drop(x) + x = self.act(x) + return x