8 changes: 7 additions & 1 deletion tests/test_models.py
@@ -210,6 +210,7 @@ def test_model_backward(model_name, batch_size):
pytest.skip("Fixed input size model > limit.")

model = create_model(model_name, pretrained=False, num_classes=42)
encoder_only = model.num_classes == 0 # FIXME better approach?
num_params = sum([x.numel() for x in model.parameters()])
model.train()

@@ -224,7 +225,12 @@ def test_model_backward(model_name, batch_size):
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])

assert outputs.shape[-1] == 42
if encoder_only:
output_fmt = getattr(model, 'output_fmt', 'NCHW')
feat_axis = get_channel_dim(output_fmt)
assert outputs.shape[feat_axis] == model.num_features, f'unpooled feature dim {outputs.shape[feat_axis]} != model.num_features {model.num_features}'
else:
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'

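The new branch covers models that end up with no classifier (num_classes == 0), asserting on the feature dimension instead of the 42-class head. A minimal sketch of that shape contract, using an arbitrary resnet18 rather than the parametrized test models (create_model and num_features are real timm APIs; the model choice and shapes here are only illustrative):

import torch
from timm import create_model

model = create_model('resnet18', pretrained=False, num_classes=0)  # classifier removed -> encoder only
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)                            # (2, 512): pooled features, no classifier head
print(out.shape[1] == model.num_features)   # True, matching the assert in the new branch
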
16 changes: 14 additions & 2 deletions timm/layers/create_norm_act.py
@@ -11,7 +11,7 @@

from .evo_norm import *
from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, RmsNormAct, RmsNormAct2d
from .inplace_abn import InplaceAbn

_NORM_ACT_MAP = dict(
@@ -34,11 +34,21 @@
frntlu=FilterResponseNormTlu2d,
inplaceabn=InplaceAbn,
iabn=InplaceAbn,
rmsnorm=RmsNormAct,
rmsnorm2d=RmsNormAct2d,
)
_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
# has act_layer arg to define act type
_NORM_ACT_REQUIRES_ARG = {
BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
BatchNormAct2d,
GroupNormAct,
LayerNormAct,
LayerNormAct2d,
FilterResponseNormAct2d,
InplaceAbn,
RmsNormAct,
RmsNormAct2d,
}


def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
@@ -83,6 +93,8 @@ def get_norm_act_layer(norm_layer, act_layer=None):
norm_act_layer = LayerNormAct2d
elif type_name.startswith('layernorm'):
norm_act_layer = LayerNormAct
elif type_name.startswith('rmsnorm2d'):
norm_act_layer = RmsNormAct2d
else:
assert False, f"No equivalent norm_act layer for {type_name}"

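With 'rmsnorm'/'rmsnorm2d' registered above, the usual norm+act resolution path picks them up by name. A usage sketch (assumes this PR is applied; the channel count and activation choice are arbitrary):

import torch
from timm.layers import get_norm_act_layer

norm_act = get_norm_act_layer('rmsnorm2d', act_layer='silu')  # resolves to RmsNormAct2d via _NORM_ACT_MAP
layer = norm_act(64)                                          # 64-channel NCHW norm + SiLU
y = layer(torch.randn(2, 64, 7, 7))
print(type(layer).__name__, y.shape)                          # RmsNormAct2d torch.Size([2, 64, 7, 7])
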
47 changes: 46 additions & 1 deletion timm/layers/fast_norm.py
@@ -148,7 +148,7 @@ def fast_rms_norm(
return fused_rms_norm_affine(x, weight, normalized_shape, eps)

if is_autocast_enabled(x.device.type):
# normally native AMP casts LN inputs to float32
# normally native AMP casts LN inputs to float32 and leaves the output as float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
x, weight = x.to(dt), weight.to(dt)
@@ -162,6 +162,51 @@
return x


def rms_norm2d(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5,
):
assert len(normalized_shape) == 1
v = x.pow(2)
v = torch.mean(v, dim=1, keepdim=True)
x = x * torch.rsqrt(v + eps)
if weight is not None:
x = x * weight.reshape(1, -1, 1, 1)
return x


def fast_rms_norm2d(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5,
) -> torch.Tensor:
if torch.jit.is_scripting():
# this must be by itself, cannot merge with has_apex_rmsnorm
return rms_norm2d(x, normalized_shape, weight, eps)

if has_apex_rmsnorm:
x = x.permute(0, 2, 3, 1)
if weight is None:
x = fused_rms_norm(x, normalized_shape, eps)
else:
x = fused_rms_norm_affine(x, weight, normalized_shape, eps)
        x = x.permute(0, 3, 1, 2)
        return x

if is_autocast_enabled(x.device.type):
# normally native AMP casts norm inputs to float32 and leaves the output as float32
# apex does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
x, weight = x.to(dt), weight.to(dt)

with torch.amp.autocast(device_type=x.device.type, enabled=False):
x = rms_norm2d(x, normalized_shape, weight, eps)

return x


def simple_norm(
x: torch.Tensor,
normalized_shape: List[int],
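The eager rms_norm2d above reduces over dim=1 of an NCHW tensor; it should match permuting to channels-last and using PyTorch's built-in rms_norm. A quick numerical check (a sketch assuming PyTorch >= 2.4 for F.rms_norm; shapes are arbitrary):

import torch
import torch.nn.functional as F
from timm.layers.fast_norm import rms_norm2d

x = torch.randn(2, 32, 8, 8)
w = torch.randn(32)
out = rms_norm2d(x, [32], w, eps=1e-5)                                            # reduce over dim=1, no permute
ref = F.rms_norm(x.permute(0, 2, 3, 1), (32,), w, eps=1e-5).permute(0, 3, 1, 2)   # channels-last reference
print(torch.allclose(out, ref, atol=1e-6))                                        # True, up to fp32 rounding
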
17 changes: 11 additions & 6 deletions timm/layers/norm.py
@@ -11,7 +11,10 @@
import torch.nn as nn
import torch.nn.functional as F

from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm
from .fast_norm import (
is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d,
fast_simple_norm, simple_norm
)

try:
from torch.nn.functional import rms_norm
@@ -173,7 +176,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class RmsNorm2d(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
""" RmsNorm2D for NCHW tensors, w/ fast apex or cast norm if available

    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does the reduction
    on dim=1 than to permute and use the internal PyTorch F.rms_norm; this may change if something
    like https://github.com/pytorch/pytorch/pull/150576 lands.
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
@@ -205,14 +212,12 @@ def reset_parameters(self) -> None:
nn.init.ones_(self.weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
        # NOTE the fast norm fallback needs our rms norm impl, so both paths go through here.
        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
if self._fast_norm:
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
x = x.permute(0, 3, 1, 2)
x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
return x


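The NOTE in the RmsNorm2d docstring is a performance claim; a rough way to check it on a given GPU (a sketch only, not a benchmark: shape, device, and iteration count are arbitrary, and results vary by hardware and PyTorch version):

import torch
import torch.nn.functional as F
from timm.layers.fast_norm import rms_norm2d

x = torch.randn(16, 256, 56, 56, device='cuda')
w = torch.randn(256, device='cuda')

def eager_2d(x):
    return rms_norm2d(x, [256], w)                  # reduction on dim=1, no permute

def permute_path(x):
    return F.rms_norm(x.permute(0, 2, 3, 1), (256,), w).permute(0, 3, 1, 2)

for name, fn in [('rms_norm2d', eager_2d), ('permute + F.rms_norm', permute_path)]:
    for _ in range(10):                             # warmup
        fn(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    print(f'{name}: {start.elapsed_time(end) / 100:.3f} ms')
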
74 changes: 73 additions & 1 deletion timm/layers/norm_act.py
@@ -20,9 +20,15 @@
from torchvision.ops.misc import FrozenBatchNorm2d

from .create_act import create_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d
from .norm import RmsNorm, RmsNorm2d
from .trace_utils import _assert

try:
from torch.nn.functional import rms_norm
except ImportError:
from .fast_norm import rms_norm


def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
act_kwargs = act_kwargs or {}
@@ -460,3 +466,69 @@ def forward(self, x):
x = self.drop(x)
x = self.act(x)
return x


class RmsNormAct(RmsNorm):
""" RMSNorm + Activation for '2D' NCHW tensors

NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
like https://github.com/pytorch/pytorch/pull/150576 lands.
"""
def __init__(
self,
num_channels,
eps=1e-6,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
act_kwargs=None,
inplace=True,
drop_layer=None,
):
super().__init__(channels=num_channels, eps=eps, affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
self._fast_norm = is_fast_norm()

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._fast_norm:
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
x = self.drop(x)
x = self.act(x)
return x


class RmsNormAct2d(RmsNorm2d):
""" RMSNorm + Activation for '2D' NCHW tensors

    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does the reduction
    on dim=1 than to permute and use the internal PyTorch F.rms_norm; this may change if something
    like https://github.com/pytorch/pytorch/pull/150576 lands.
"""
def __init__(
self,
num_channels,
eps=1e-6,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
act_kwargs=None,
inplace=True,
drop_layer=None,
):
super().__init__(channels=num_channels, eps=eps, affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
self._fast_norm = is_fast_norm()

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._fast_norm:
x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
x = self.drop(x)
x = self.act(x)
return x
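
From the builder's perspective RmsNormAct2d is a drop-in for BatchNormAct2d: normalize over the channels of an NCHW tensor, then apply an optional activation. A direct-use sketch (arguments beyond num_channels/act_layer left at their defaults):

import torch
import torch.nn as nn
from timm.layers.norm_act import RmsNormAct2d

layer = RmsNormAct2d(96, act_layer=nn.SiLU)
x = torch.randn(4, 96, 14, 14)
y = layer(x)
print(y.shape)    # torch.Size([4, 96, 14, 14]); spatial layout preserved, normed over dim=1
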
2 changes: 2 additions & 0 deletions timm/models/__init__.py
@@ -40,6 +40,7 @@
from .metaformer import *
from .mlp_mixer import *
from .mobilenetv3 import *
from .mobilenetv5 import *
from .mobilevit import *
from .mvitv2 import *
from .naflexvit import *
@@ -129,6 +130,7 @@
load_model_config_from_hf as load_model_config_from_hf,
load_state_dict_from_hf as load_state_dict_from_hf,
push_to_hf_hub as push_to_hf_hub,
save_for_hf as save_for_hf,
)
from ._manipulate import (
model_parameters as model_parameters,
5 changes: 3 additions & 2 deletions timm/models/_efficientnet_blocks.py
@@ -517,12 +517,13 @@ def __init__(
value_dim=value_dim,
query_strides=query_strides,
kv_stride=kv_stride,
dw_kernel_size=dw_kernel_size,
dilation=dilation,
padding=pad_type,
dw_kernel_size=dw_kernel_size,
attn_drop=attn_drop,
proj_drop=proj_drop,
#bias=use_bias, # why not here if used w/ mhsa?
norm_layer=norm_layer,
# use_bias=use_bias, # why not here if used w/ mhsa?
)
else:
self.attn = Attention2d(