From c08db89889255605f310772a5e73c924fedca052 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 4 Jul 2025 13:06:34 -0700 Subject: [PATCH 1/2] Add flag to enable float32 computation for normalization (norm + affine) in several timm norm & norm+act layers --- timm/layers/create_act.py | 13 ++- timm/layers/norm.py | 144 +++++++++++++++++++++++----- timm/layers/norm_act.py | 194 ++++++++++++++++++++++++-------------- 3 files changed, 250 insertions(+), 101 deletions(-) diff --git a/timm/layers/create_act.py b/timm/layers/create_act.py index c734785d13..b066871078 100644 --- a/timm/layers/create_act.py +++ b/timm/layers/create_act.py @@ -1,11 +1,12 @@ """ Activation Factory Hacked together by / Copyright 2020 Ross Wightman """ -from typing import Union, Callable, Type +from typing import Callable, Optional, Type, Union from .activations import * from .activations_me import * from .config import is_exportable, is_scriptable +from .typing import LayerType # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. @@ -88,7 +89,7 @@ a.setdefault('hardswish', a.get('hard_swish')) -def get_act_fn(name: Union[Callable, str] = 'relu'): +def get_act_fn(name: Optional[LayerType] = 'relu'): """ Activation Function Factory Fetching activation fns by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. @@ -106,7 +107,7 @@ def get_act_fn(name: Union[Callable, str] = 'relu'): return _ACT_FN_DEFAULT[name] -def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): +def get_act_layer(name: Optional[LayerType] = 'relu'): """ Activation Layer Factory Fetching activation layers by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. 
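Usage sketch (not part of the patch): the factory functions above now take an Optional[LayerType] (a string name, an nn.Module class/callable, or None) and return None when no activation is requested. A minimal illustration of how they are typically called, assuming these factories are re-exported from timm.layers as in the released package; the concrete class or function returned for a given name (e.g. nn.SiLU vs. a timm variant) depends on the exportable/scriptable config mentioned in the docstrings.

    import torch
    from timm.layers import create_act_layer, get_act_fn, get_act_layer

    act_cls = get_act_layer('silu')               # a type, e.g. nn.SiLU or a timm variant per config
    act_fn = get_act_fn('gelu')                   # a plain callable operating on tensors
    act = create_act_layer('relu', inplace=True)  # an instantiated module, nn.ReLU(inplace=True)
    assert create_act_layer(None) is None         # None-safe: no activation requested

    x = torch.randn(2, 8)
    y = act_cls()(x)                              # class form: instantiate, then call
    y = act_fn(y)                                 # functional form
    y = act(y)                                    # pre-built module from create_act_layer
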
@@ -125,7 +126,11 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): return _ACT_LAYER_DEFAULT[name] -def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs): +def create_act_layer( + name: Optional[LayerType], + inplace: Optional[bool] = None, + **kwargs +): act_layer = get_act_layer(name) if act_layer is None: return None diff --git a/timm/layers/norm.py b/timm/layers/norm.py index 3a6e87524e..7e9a18185c 100644 --- a/timm/layers/norm.py +++ b/timm/layers/norm.py @@ -12,8 +12,14 @@ import torch.nn.functional as F from .fast_norm import ( - is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d, - fast_simple_norm, simple_norm + is_fast_norm, + fast_group_norm, + fast_layer_norm, + fast_rms_norm, + rms_norm2d, + fast_rms_norm2d, + fast_simple_norm, + simple_norm, ) try: @@ -24,15 +30,27 @@ class GroupNorm(nn.GroupNorm): _fast_norm: torch.jit.Final[bool] - - def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): + _fp32_norm: torch.jit.Final[bool] + + def __init__( + self, + num_channels: int, + num_groups: int = 32, + eps: float = 1e-5, + affine: bool = True, + use_fp32: bool = False, + **kwargs, + ): # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN - super().__init__(num_groups, num_channels, eps=eps, affine=affine) + super().__init__(num_groups, num_channels, eps=eps, affine=affine, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + self._fp32_norm = use_fp32 def forward(self, x): if self._fast_norm: return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + elif self._fp32_norm: + return F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype) else: return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) @@ -42,14 +60,18 @@ class GroupNorm1(nn.GroupNorm): Input: tensor in shape [B, C, *] """ _fast_norm: torch.jit.Final[bool] + _fp32_norm: torch.jit.Final[bool] - def __init__(self, num_channels, **kwargs): + def __init__(self, num_channels: int, use_fp32: bool = False, **kwargs): super().__init__(1, num_channels, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + self._fp32_norm = use_fp32 def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + elif self._fp32_norm: + return F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype) else: return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) @@ -58,14 +80,25 @@ class LayerNorm(nn.LayerNorm): """ LayerNorm w/ fast norm option """ _fast_norm: torch.jit.Final[bool] - - def __init__(self, num_channels, eps=1e-6, affine=True): - super().__init__(num_channels, eps=eps, elementwise_affine=affine) + _fp32_norm: torch.jit.Final[bool] + + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + use_fp32: bool = False, + **kwargs, + ): + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + self._fp32_norm = use_fp32 def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self._fp32_norm: + x = F.layer_norm(x.float(), self.normalized_shape, self.weight,
self.bias, self.eps).to(x.dtype) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return x @@ -74,15 +107,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LayerNorm2d(nn.LayerNorm): """ LayerNorm for channels of '2D' spatial NCHW tensors """ _fast_norm: torch.jit.Final[bool] - - def __init__(self, num_channels, eps=1e-6, affine=True): - super().__init__(num_channels, eps=eps, elementwise_affine=affine) + _fp32_norm: torch.jit.Final[bool] + + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + use_fp32: bool = False, + **kwargs, + ): + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + self._fp32_norm = use_fp32 def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 3, 1) if self._fast_norm: x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self._fp32_norm: + x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2) @@ -121,7 +165,7 @@ class LayerNormExp2d(nn.LayerNorm): layout. However, benefits are not always clear and can perform worse on other GPUs. """ - def __init__(self, num_channels, eps=1e-6): + def __init__(self, num_channels: int, eps: float = 1e-6): super().__init__(num_channels, eps=eps) def forward(self, x) -> torch.Tensor: @@ -136,13 +180,22 @@ def forward(self, x) -> torch.Tensor: class RmsNorm(nn.Module): """ RmsNorm w/ fast (apex) norm if available """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm'] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool _fast_norm: bool - - def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: + _fp32_norm: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + use_fp32: bool = False, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels @@ -153,6 +206,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> self.eps = eps self.elementwise_affine = affine self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + self._fp32_norm = use_fp32 if self.elementwise_affine: self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -167,9 +221,11 @@ def reset_parameters(self) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE fast norm fallback needs our rms norm impl, so both paths through here. - # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed. + # Since there is no built-in PyTorch impl, always uses APEX RmsNorm if installed. if self._fast_norm: x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps) + elif self._fp32_norm: + x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = rms_norm(x, self.normalized_shape, self.weight, self.eps) return x @@ -182,13 +238,22 @@ class RmsNorm2d(nn.Module): on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something like https://github.com/pytorch/pytorch/pull/150576 lands. 
""" - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm'] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool _fast_norm: bool - - def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: + _fp32_norm: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + use_fp32: bool = False, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels @@ -199,6 +264,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> self.eps = eps self.elementwise_affine = affine self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + self._fp32_norm = use_fp32 if self.elementwise_affine: self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -216,6 +282,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed. if self._fast_norm: x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps) + elif self._fp32_norm: + x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps) return x @@ -224,13 +292,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SimpleNorm(nn.Module): """ SimpleNorm (x / std(x)) """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm'] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool _fast_norm: bool - - def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: + _fp32_norm: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + use_fp32: bool = False, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels @@ -241,6 +318,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> self.eps = eps self.elementwise_affine = affine self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + self._fp32_norm = use_fp32 if self.elementwise_affine: self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -256,6 +334,8 @@ def reset_parameters(self) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps) + elif self._fp32_norm: + x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = simple_norm(x, self.normalized_shape, self.weight, self.eps) return x @@ -264,13 +344,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SimpleNorm2d(nn.Module): """ SimpleNorm for NCHW tensors """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm'] normalized_shape: Tuple[int, ...] 
eps: float elementwise_affine: bool _fast_norm: bool - - def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: + _fp32_norm: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + use_fp32: bool = False, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels @@ -281,6 +370,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> self.eps = eps self.elementwise_affine = affine self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + self._fp32_norm = use_fp32 if self.elementwise_affine: self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -297,6 +387,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 3, 1) if self._fast_norm: x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps) + elif self._fp32_norm: + x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = simple_norm(x, self.normalized_shape, self.weight, self.eps) x = x.permute(0, 3, 1, 2) diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index 42eb2da609..886067e125 100644 --- a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -12,7 +12,7 @@ Hacked together by / Copyright 2022 Ross Wightman """ -from typing import Union, List, Optional, Any +from typing import Any, Dict, List, Optional, Type, Union import torch from torch import nn as nn @@ -21,9 +21,17 @@ from ._fx import register_notrace_module from .create_act import create_act_layer -from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d +from .fast_norm import ( + is_fast_norm, + fast_group_norm, + fast_layer_norm, + fast_rms_norm, + rms_norm2d, + fast_rms_norm2d, +) from .norm import RmsNorm, RmsNorm2d from .trace_utils import _assert +from .typing import LayerType try: from torch.nn.functional import rms_norm @@ -31,7 +39,12 @@ from .fast_norm import rms_norm -def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True): +def _create_act( + act_layer: LayerType, + act_kwargs: Dict[str, Any] = None, + inplace: Optional[bool] = False, + apply_act: bool = True, +) -> nn.Module: act_kwargs = act_kwargs or {} act_kwargs.setdefault('inplace', inplace) act = None @@ -50,16 +63,16 @@ class BatchNormAct2d(nn.BatchNorm2d): """ def __init__( self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - track_running_stats=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, device=None, dtype=None, ): @@ -208,11 +221,11 @@ def __init__( self, num_features: int, eps: float = 1e-5, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, ): super().__init__() self.eps = eps @@ -344,7 +357,7 @@ def unfreeze_batch_norm_2d(module): return res -def _num_groups(num_channels, num_groups, group_size): +def _num_groups(num_channels: int, 
num_groups: int, group_size: int): if group_size: assert num_channels % group_size == 0 return num_channels // group_size @@ -352,19 +365,23 @@ def _num_groups(num_channels, num_groups, group_size): class GroupNormAct(nn.GroupNorm): + _fast_norm: torch.jit.Final[bool] + _fp32_norm: torch.jit.Final[bool] + # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args def __init__( self, - num_channels, - num_groups=32, - eps=1e-5, - affine=True, - group_size=None, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_channels: int, + num_groups: int = 32, + eps: float = 1e-5, + affine: bool = True, + group_size: Optional[int] = None, + use_fp32: bool = False, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, ): super(GroupNormAct, self).__init__( _num_groups(num_channels, num_groups, group_size), @@ -376,10 +393,13 @@ def __init__( self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() + self._fp32_norm = use_fp32 def forward(self, x): if self._fast_norm: x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + elif self._fp32_norm: + x = F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) x = self.drop(x) @@ -388,26 +408,33 @@ def forward(self, x): class GroupNorm1Act(nn.GroupNorm): + _fast_norm: torch.jit.Final[bool] + _fp32_norm: torch.jit.Final[bool] + def __init__( self, - num_channels, - eps=1e-5, - affine=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + use_fp32: bool = False, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, ): super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() + self._fp32_norm = use_fp32 def forward(self, x): if self._fast_norm: x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + elif self._fp32_norm: + x = F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) x = self.drop(x) @@ -416,26 +443,34 @@ def forward(self, x): class LayerNormAct(nn.LayerNorm): + _fast_norm: torch.jit.Final[bool] + _fp32_norm: torch.jit.Final[bool] + def __init__( self, normalization_shape: Union[int, List[int], torch.Size], - eps=1e-5, - affine=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + eps: float = 1e-5, + affine: bool = True, + use_fp32: bool = False, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, ): - super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine) + super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine, **kwargs) self.drop = drop_layer() if 
drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() + self._fp32_norm = use_fp32 def forward(self, x): if self._fast_norm: x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self._fp32_norm: + x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = self.drop(x) @@ -444,26 +479,34 @@ def forward(self, x): class LayerNormAct2d(nn.LayerNorm): + _fast_norm: torch.jit.Final[bool] + _fp32_norm: torch.jit.Final[bool] + def __init__( self, - num_channels, - eps=1e-5, - affine=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + use_fp32: bool = False, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, ): - super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine) + super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() + self._fp32_norm = use_fp32 # for norm def forward(self, x): x = x.permute(0, 2, 3, 1) if self._fast_norm: x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self._fp32_norm: + x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2) @@ -481,23 +524,28 @@ class RmsNormAct(RmsNorm): """ def __init__( self, - num_channels, - eps=1e-6, - affine=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + use_fp32: bool = False, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, ): - super().__init__(channels=num_channels, eps=eps, affine=affine) + super().__init__(channels=num_channels, eps=eps, affine=affine, **kwargs) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() + self._fp32_norm = use_fp32 # for norm def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps) + elif self._fp32_norm: + x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = rms_norm(x, self.normalized_shape, self.weight, self.eps) x = self.drop(x) @@ -514,23 +562,27 @@ class RmsNormAct2d(RmsNorm2d): """ def __init__( self, - num_channels, - eps=1e-6, - affine=True, - apply_act=True, - act_layer=nn.ReLU, - act_kwargs=None, - inplace=True, - drop_layer=None, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + use_fp32: bool = False, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = 
True, + drop_layer: Optional[Type[nn.Module]] = None, ): super().__init__(channels=num_channels, eps=eps, affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() + self._fp32_norm = use_fp32 # for norm def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps) + elif self._fp32_norm: + x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps) x = self.drop(x) From 5c720aa97ec3687f52c77ea198017d4a23f4dc20 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 4 Jul 2025 15:38:59 -0700 Subject: [PATCH 2/2] Change of plan, make separate 'Fp32' norm and norm+act instances to reduce chance of regressions --- timm/layers/__init__.py | 23 ++- timm/layers/create_norm.py | 23 ++- timm/layers/create_norm_act.py | 74 +++++++--- timm/layers/norm.py | 260 +++++++++++++++++++++++++++------ timm/layers/norm_act.py | 141 ++++++++++++++---- 5 files changed, 428 insertions(+), 93 deletions(-) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 5c474fefcb..95595cd1b4 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -69,13 +69,34 @@ from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d +from .norm import ( + GroupNorm, + GroupNorm1, + LayerNorm, + LayerNorm2d, + LayerNormFp32, + LayerNorm2dFp32, + RmsNorm, + RmsNorm2d, + RmsNormFp32, + RmsNorm2dFp32, + SimpleNorm, + SimpleNorm2d, + SimpleNormFp32, + SimpleNorm2dFp32, +) from .norm_act import ( BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d, + LayerNormActFp32, + LayerNormAct2dFp32, + RmsNormAct, + RmsNormAct2d, + RmsNormActFp32, + RmsNormAct2dFp32, SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, diff --git a/timm/layers/create_norm.py b/timm/layers/create_norm.py index 75262b5eca..0821be7fce 100644 --- a/timm/layers/create_norm.py +++ b/timm/layers/create_norm.py @@ -10,7 +10,22 @@ import torch.nn as nn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d +from .norm import ( + GroupNorm, + GroupNorm1, + LayerNorm, + LayerNorm2d, + LayerNormFp32, + LayerNorm2dFp32, + RmsNorm, + RmsNorm2d, + RmsNormFp32, + RmsNorm2dFp32, + SimpleNorm, + SimpleNorm2d, + SimpleNormFp32, + SimpleNorm2dFp32, +) from torchvision.ops.misc import FrozenBatchNorm2d _NORM_MAP = dict( @@ -21,10 +36,16 @@ groupnorm1=GroupNorm1, layernorm=LayerNorm, layernorm2d=LayerNorm2d, + layernormfp32=LayerNormFp32, + layernorm2dfp32=LayerNorm2dFp32, rmsnorm=RmsNorm, rmsnorm2d=RmsNorm2d, + rmsnormfp32=RmsNormFp32, + rmsnorm2dfp32=RmsNorm2dFp32, simplenorm=SimpleNorm, simplenorm2d=SimpleNorm2d, + simplenormfp32=SimpleNormFp32, + simplenorm2dfp32=SimpleNorm2dFp32, frozenbatchnorm2d=FrozenBatchNorm2d, ) _NORM_TYPES = {m for n, m in _NORM_MAP.items()} diff --git a/timm/layers/create_norm_act.py b/timm/layers/create_norm_act.py index e7805e7993..4b8d2bf2ff 100644 --- a/timm/layers/create_norm_act.py +++ b/timm/layers/create_norm_act.py @@ -8,19 +8,35 @@ """ import types import functools +from
typing import Optional from .evo_norm import * from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d -from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, RmsNormAct, RmsNormAct2d +from .norm_act import ( + BatchNormAct2d, + GroupNormAct, + GroupNorm1Act, + LayerNormAct, + LayerNormActFp32, + LayerNormAct2d, + LayerNormAct2dFp32, + RmsNormAct, + RmsNormActFp32, + RmsNormAct2d, + RmsNormAct2dFp32, +) from .inplace_abn import InplaceAbn +from .typing import LayerType _NORM_ACT_MAP = dict( batchnorm=BatchNormAct2d, batchnorm2d=BatchNormAct2d, groupnorm=GroupNormAct, - groupnorm1=functools.partial(GroupNormAct, num_groups=1), + groupnorm1=GroupNorm1Act, layernorm=LayerNormAct, layernorm2d=LayerNormAct2d, + layernormfp32=LayerNormActFp32, + layernorm2dfp32=LayerNormAct2dFp32, evonormb0=EvoNorm2dB0, evonormb1=EvoNorm2dB1, evonormb2=EvoNorm2dB2, @@ -36,22 +52,51 @@ iabn=InplaceAbn, rmsnorm=RmsNormAct, rmsnorm2d=RmsNormAct2d, + rmsnormfp32=RmsNormActFp32, + rmsnorm2dfp32=RmsNormAct2dFp32, ) _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} +# Reverse map from base norm layer names to norm+act layer classes +_NORM_TO_NORM_ACT_MAP = dict( + batchnorm=BatchNormAct2d, + batchnorm2d=BatchNormAct2d, + groupnorm=GroupNormAct, + groupnorm1=GroupNorm1Act, + layernorm=LayerNormAct, + layernorm2d=LayerNormAct2d, + layernormfp32=LayerNormActFp32, + layernorm2dfp32=LayerNormAct2dFp32, + rmsnorm=RmsNormAct, + rmsnorm2d=RmsNormAct2d, + rmsnormfp32=RmsNormActFp32, + rmsnorm2dfp32=RmsNormAct2dFp32, +) # has act_layer arg to define act type _NORM_ACT_REQUIRES_ARG = { BatchNormAct2d, GroupNormAct, + GroupNorm1Act, LayerNormAct, LayerNormAct2d, + LayerNormActFp32, + LayerNormAct2dFp32, FilterResponseNormAct2d, InplaceAbn, RmsNormAct, RmsNormAct2d, + RmsNormActFp32, + RmsNormAct2dFp32, } -def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): +def create_norm_act_layer( + layer_name: LayerType, + num_features: int, + act_layer: Optional[LayerType] = None, + apply_act: bool = True, + jit: bool = False, + **kwargs, +): layer = get_norm_act_layer(layer_name, act_layer=act_layer) layer_instance = layer(num_features, apply_act=apply_act, **kwargs) if jit: @@ -59,7 +104,10 @@ def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=Tr return layer_instance -def get_norm_act_layer(norm_layer, act_layer=None): +def get_norm_act_layer( + norm_layer: LayerType, + act_layer: Optional[LayerType] = None, +): if norm_layer is None: return None assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) @@ -82,21 +130,10 @@ def get_norm_act_layer(norm_layer, act_layer=None): # if function type, must be a lambda/fn that creates a norm_act layer norm_act_layer = norm_layer else: + # Use reverse map to find the corresponding norm+act layer type_name = norm_layer.__name__.lower() - if type_name.startswith('batchnorm'): - norm_act_layer = BatchNormAct2d - elif type_name.startswith('groupnorm'): - norm_act_layer = GroupNormAct - elif type_name.startswith('groupnorm1'): - norm_act_layer = functools.partial(GroupNormAct, num_groups=1) - elif type_name.startswith('layernorm2d'): - norm_act_layer = LayerNormAct2d - elif type_name.startswith('layernorm'): - norm_act_layer = LayerNormAct - elif type_name.startswith('rmsnorm2d'): - norm_act_layer = RmsNormAct2d - else: - assert False, f"No equivalent norm_act layer for {type_name}" + norm_act_layer = _NORM_TO_NORM_ACT_MAP.get(type_name, 
None) + assert norm_act_layer is not None, f"No equivalent norm_act layer for {type_name}" if norm_act_layer in _NORM_ACT_REQUIRES_ARG: # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. @@ -104,4 +141,5 @@ def get_norm_act_layer(norm_layer, act_layer=None): norm_act_kwargs.setdefault('act_layer', act_layer) if norm_act_kwargs: norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args + return norm_act_layer diff --git a/timm/layers/norm.py b/timm/layers/norm.py index 7e9a18185c..79a19bf323 100644 --- a/timm/layers/norm.py +++ b/timm/layers/norm.py @@ -30,7 +30,6 @@ class GroupNorm(nn.GroupNorm): _fast_norm: torch.jit.Final[bool] - _fp32_norm: torch.jit.Final[bool] def __init__( self, @@ -38,19 +37,15 @@ def __init__( num_groups: int = 32, eps: float = 1e-5, affine: bool = True, - use_fp32: bool = False, **kwargs, ): # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN super().__init__(num_groups, num_channels, eps=eps, affine=affine, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) - self._fp32_norm = use_fp32 def forward(self, x): if self._fast_norm: return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) - elif self._fp32_norm: - return F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype) else: return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) @@ -60,18 +55,14 @@ class GroupNorm1(nn.GroupNorm): Input: tensor in shape [B, C, *] """ _fast_norm: torch.jit.Final[bool] - _fp32_norm: torch.jit.Final[bool] - def __init__(self, num_channels: int, use_fp32: bool = False, **kwargs): + def __init__(self, num_channels: int, **kwargs): super().__init__(1, num_channels, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) - self._fp32_norm = use_fp32 def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) - elif self._fp32_norm: - return F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.device) else: return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) @@ -80,59 +71,86 @@ class LayerNorm(nn.LayerNorm): """ LayerNorm w/ fast norm option """ _fast_norm: torch.jit.Final[bool] - _fp32_norm: torch.jit.Final[bool] def __init__( self, num_channels: int, eps: float = 1e-6, affine: bool = True, - use_fp32: bool = False, **kwargs, ): super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) - self._fp32_norm = use_fp32 def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self._fp32_norm: - x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return x +class LayerNormFp32(nn.LayerNorm): + """ LayerNorm + """ + + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + **kwargs, + ): + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, 
self.eps).to(x.dtype) + return x + + class LayerNorm2d(nn.LayerNorm): """ LayerNorm for channels of '2D' spatial NCHW tensors """ _fast_norm: torch.jit.Final[bool] - _fp32_norm: torch.jit.Final[bool] def __init__( self, num_channels: int, eps: float = 1e-6, affine: bool = True, - use_fp32: bool = False, **kwargs, ): super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) - self._fp32_norm = use_fp32 def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 3, 1) if self._fast_norm: x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self._fp32_norm: - x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2) return x +class LayerNorm2dFp32(nn.LayerNorm): + """ LayerNorm for channels of '2D' spatial NCHW tensors """ + + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + **kwargs, + ): + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 3, 1) + x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) + x = x.permute(0, 3, 1, 2) + return x + + def _is_contiguous(tensor: torch.Tensor) -> bool: # jit is oh so lovely :/ if torch.jit.is_scripting(): @@ -180,19 +198,17 @@ def forward(self, x) -> torch.Tensor: class RmsNorm(nn.Module): """ RmsNorm w/ fast (apex) norm if available """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm'] + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool _fast_norm: bool - _fp32_norm: bool def __init__( self, channels: int, eps: float = 1e-6, affine: bool = True, - use_fp32: bool = False, device=None, dtype=None, ) -> None: @@ -206,7 +222,6 @@ def __init__( self.eps = eps self.elementwise_affine = affine self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) - self._fp32_norm = use_fp32 if self.elementwise_affine: self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -224,13 +239,53 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Since there is no built-in PyTorch impl, always uses APEX RmsNorm if installed. if self._fast_norm: x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps) - elif self._fp32_norm: - x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = rms_norm(x, self.normalized_shape, self.weight, self.eps) return x +class RmsNormFp32(nn.Module): + """ RmsNorm w/ fast (apex) norm if available + """ + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + return x + + class RmsNorm2d(nn.Module): """ RmsNorm2D for NCHW tensors, w/ fast apex or cast norm if available @@ -238,19 +293,17 @@ class RmsNorm2d(nn.Module): on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something like https://github.com/pytorch/pytorch/pull/150576 lands. """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm'] + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool _fast_norm: bool - _fp32_norm: bool def __init__( self, channels: int, eps: float = 1e-6, affine: bool = True, - use_fp32: bool = False, device=None, dtype=None, ) -> None: @@ -264,7 +317,6 @@ def __init__( self.eps = eps self.elementwise_affine = affine self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) - self._fp32_norm = use_fp32 if self.elementwise_affine: self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -282,29 +334,71 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed. if self._fast_norm: x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps) - elif self._fp32_norm: - x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps) return x +class RmsNorm2dFp32(nn.Module): + """ RmsNorm2D for NCHW tensors, w/ fast apex or cast norm if available + + NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction + on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something + like https://github.com/pytorch/pytorch/pull/150576 lands. + """ + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + return x + + class SimpleNorm(nn.Module): """ SimpleNorm (x / std(x)) """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm'] + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool _fast_norm: bool - _fp32_norm: bool def __init__( self, channels: int, eps: float = 1e-6, affine: bool = True, - use_fp32: bool = False, device=None, dtype=None, ) -> None: @@ -318,7 +412,6 @@ def __init__( self.eps = eps self.elementwise_affine = affine self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) - self._fp32_norm = use_fp32 if self.elementwise_affine: self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -334,29 +427,67 @@ def reset_parameters(self) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps) - elif self._fp32_norm: - x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = simple_norm(x, self.normalized_shape, self.weight, self.eps) return x +class SimpleNormFp32(nn.Module): + """ SimpleNorm (x / std(x)) + """ + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + return x + + class SimpleNorm2d(nn.Module): """ SimpleNorm for NCHW tensors """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm'] + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool _fast_norm: bool - _fp32_norm: bool def __init__( self, channels: int, eps: float = 1e-6, affine: bool = True, - use_fp32: bool = False, device=None, dtype=None, ) -> None: @@ -370,7 +501,6 @@ def __init__( self.eps = eps self.elementwise_affine = affine self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) - self._fp32_norm = use_fp32 if self.elementwise_affine: self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -387,9 +517,51 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 3, 1) if self._fast_norm: x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps) - elif self._fp32_norm: - x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = simple_norm(x, self.normalized_shape, self.weight, self.eps) x = x.permute(0, 3, 1, 2) return x + + +class SimpleNorm2dFp32(nn.Module): + """ SimpleNorm for NCHW tensors + """ + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + + def __init__( + self, + channels: int, + eps: float = 1e-6, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 3, 1) + x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + x = x.permute(0, 3, 1, 2) + return x diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index 886067e125..7dbb5e0f2c 100644 --- a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -366,7 +366,6 @@ def _num_groups(num_channels: int, num_groups: int, group_size: int): class GroupNormAct(nn.GroupNorm): _fast_norm: torch.jit.Final[bool] - _fp32_norm: torch.jit.Final[bool] # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args def __init__( @@ -376,7 +375,6 @@ def __init__( eps: float = 1e-5, affine: bool = True, group_size: Optional[int] = None, - use_fp32: bool = False, apply_act: bool = True, act_layer: LayerType = nn.ReLU, act_kwargs: Dict[str, Any] = None, @@ -393,13 +391,10 @@ def __init__( self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() - self._fp32_norm = use_fp32 def forward(self, x): if self._fast_norm: x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) - elif self._fp32_norm: - x = F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) x = self.drop(x) @@ -409,14 +404,12 @@ def forward(self, x): class GroupNorm1Act(nn.GroupNorm): _fast_norm: torch.jit.Final[bool] - _fp32_norm: torch.jit.Final[bool] def __init__( self, num_channels: int, eps: float = 1e-5, affine: bool = True, - use_fp32: bool = False, apply_act: bool = True, act_layer: LayerType = nn.ReLU, act_kwargs: Dict[str, Any] = None, @@ -428,13 +421,10 @@ def __init__( self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() - self._fp32_norm = use_fp32 def forward(self, x): if self._fast_norm: x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) - elif self._fp32_norm: - x = F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) x = self.drop(x) @@ -444,14 +434,12 @@ def forward(self, x): class LayerNormAct(nn.LayerNorm): _fast_norm: torch.jit.Final[bool] - _fp32_norm: torch.jit.Final[bool] def __init__( self, normalization_shape: Union[int, List[int], torch.Size], eps: float = 1e-5, affine: bool = True, - use_fp32: bool = False, apply_act: bool = True, act_layer: LayerType = nn.ReLU, act_kwargs: Dict[str, Any] = None, @@ -464,13 +452,10 @@ def 
__init__( self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() - self._fp32_norm = use_fp32 def forward(self, x): if self._fast_norm: x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self._fp32_norm: - x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = self.drop(x) @@ -478,16 +463,39 @@ def forward(self, x): return x +class LayerNormActFp32(nn.LayerNorm): + + def __init__( + self, + normalization_shape: Union[int, List[int], torch.Size], + eps: float = 1e-5, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, + ): + super().__init__(normalization_shape, eps=eps, elementwise_affine=affine, **kwargs) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) + + def forward(self, x): + x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) + x = self.drop(x) + x = self.act(x) + return x + + class LayerNormAct2d(nn.LayerNorm): _fast_norm: torch.jit.Final[bool] - _fp32_norm: torch.jit.Final[bool] def __init__( self, num_channels: int, eps: float = 1e-5, affine: bool = True, - use_fp32: bool = False, apply_act: bool = True, act_layer: LayerType = nn.ReLU, act_kwargs: Dict[str, Any] = None, @@ -495,18 +503,15 @@ def __init__( drop_layer: Optional[Type[nn.Module]] = None, **kwargs, ): - super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() - self._fp32_norm = use_fp32 # for norm def forward(self, x): x = x.permute(0, 2, 3, 1) if self._fast_norm: x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self._fp32_norm: - x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2) @@ -515,6 +520,33 @@ def forward(self, x): return x +class LayerNormAct2dFp32(nn.LayerNorm): + + def __init__( + self, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, + ): + super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype) + x = x.permute(0, 3, 1, 2) + x = self.drop(x) + x = self.act(x) + return x + + class RmsNormAct(RmsNorm): """ RMSNorm + Activation for '2D' NCHW tensors @@ -527,7 +559,6 @@ def __init__( num_channels: int, eps: float = 1e-6, affine: bool = 
True, - use_fp32: bool = False, apply_act: bool = True, act_layer: LayerType = nn.ReLU, act_kwargs: Dict[str, Any] = None, @@ -539,13 +570,10 @@ def __init__( self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() - self._fp32_norm = use_fp32 # for norm def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps) - elif self._fp32_norm: - x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = rms_norm(x, self.normalized_shape, self.weight, self.eps) x = self.drop(x) @@ -553,6 +581,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class RmsNormActFp32(RmsNorm): + """ RMSNorm + Activation for '2D' NCHW tensors + + NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction + on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something + like https://github.com/pytorch/pytorch/pull/150576 lands. + """ + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, + ): + super().__init__(channels=num_channels, eps=eps, affine=affine, **kwargs) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + x = self.drop(x) + x = self.act(x) + return x + + class RmsNormAct2d(RmsNorm2d): """ RMSNorm + Activation for '2D' NCHW tensors @@ -565,7 +623,6 @@ def __init__( num_channels: int, eps: float = 1e-6, affine: bool = True, - use_fp32: bool = False, apply_act: bool = True, act_layer: LayerType = nn.ReLU, act_kwargs: Dict[str, Any] = None, @@ -576,15 +633,41 @@ def __init__( self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() - self._fp32_norm = use_fp32 # for norm def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps) - elif self._fp32_norm: - x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) else: x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps) x = self.drop(x) x = self.act(x) return x + + +class RmsNormAct2dFp32(RmsNorm2d): + """ RMSNorm + Activation for '2D' NCHW tensors + + NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction + on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something + like https://github.com/pytorch/pytorch/pull/150576 lands. 
+ """ + def __init__( + self, + num_channels: int, + eps: float = 1e-6, + affine: bool = True, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + ): + super().__init__(channels=num_channels, eps=eps, affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype) + x = self.drop(x) + x = self.act(x) + return x