23 changes: 22 additions & 1 deletion timm/layers/__init__.py
@@ -69,13 +69,34 @@
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
from .norm import (
GroupNorm,
GroupNorm1,
LayerNorm,
LayerNorm2d,
LayerNormFp32,
LayerNorm2dFp32,
RmsNorm,
RmsNorm2d,
RmsNormFp32,
RmsNorm2dFp32,
SimpleNorm,
SimpleNorm2d,
SimpleNormFp32,
SimpleNorm2dFp32,
)
from .norm_act import (
BatchNormAct2d,
GroupNormAct,
GroupNorm1Act,
LayerNormAct,
LayerNormAct2d,
LayerNormActFp32,
LayerNormAct2dFp32,
RmsNormAct,
RmsNormAct2d,
RmsNormActFp32,
RmsNormAct2dFp32,
SyncBatchNormAct,
convert_sync_batchnorm,
FrozenBatchNormAct2d,
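
For reference, a minimal usage sketch of the newly exported Fp32 norm variants. It assumes the Fp32 classes keep the constructor signature of their base counterparts (a single channel/feature count) and only change the compute dtype; that is inferred from the naming, not verified against the implementation.

```python
import torch
from timm.layers import LayerNorm2dFp32, RmsNormFp32, SimpleNorm2dFp32

# Assumption: Fp32 variants take the same channel argument as the base
# classes and run the normalization in float32 internally before
# returning a tensor in the input dtype.
norm = LayerNorm2dFp32(64)
x = torch.randn(2, 64, 8, 8, dtype=torch.float16)
y = norm(x)
print(y.shape, y.dtype)
```
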
13 changes: 9 additions & 4 deletions timm/layers/create_act.py
@@ -1,11 +1,12 @@
""" Activation Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Union, Callable, Type
from typing import Callable, Optional, Type, Union

from .activations import *
from .activations_me import *
from .config import is_exportable, is_scriptable
from .typing import LayerType

# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
@@ -88,7 +89,7 @@
a.setdefault('hardswish', a.get('hard_swish'))


def get_act_fn(name: Union[Callable, str] = 'relu'):
def get_act_fn(name: Optional[LayerType] = 'relu'):
""" Activation Function Factory
Fetching activation fns by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
@@ -106,7 +107,7 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
return _ACT_FN_DEFAULT[name]


def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
def get_act_layer(name: Optional[LayerType] = 'relu'):
""" Activation Layer Factory
Fetching activation layers by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
@@ -125,7 +126,11 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
return _ACT_LAYER_DEFAULT[name]


def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
def create_act_layer(
name: Optional[LayerType],
inplace: Optional[bool] = None,
**kwargs
):
act_layer = get_act_layer(name)
if act_layer is None:
return None
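
A small sketch of how the relaxed `Optional[LayerType]` signatures might be exercised. The pass-through behaviour for class arguments and the `None` short-circuit in `get_act_layer` are assumptions based on the surrounding factory code, which is not shown in this hunk.

```python
import torch.nn as nn
from timm.layers import get_act_layer, create_act_layer

# By-name lookup resolves through the registered defaults.
silu_cls = get_act_layer('silu')
# Assumption: an nn.Module class passed as LayerType is returned as-is.
gelu_cls = get_act_layer(nn.GELU)

act = create_act_layer('relu', inplace=True)   # instantiates nn.ReLU(inplace=True)
noop = create_act_layer(None)                  # Optional name -> returns None
print(silu_cls, gelu_cls, act, noop)
```
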
23 changes: 22 additions & 1 deletion timm/layers/create_norm.py
@@ -10,7 +10,22 @@

import torch.nn as nn

from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
from .norm import (
GroupNorm,
GroupNorm1,
LayerNorm,
LayerNorm2d,
LayerNormFp32,
LayerNorm2dFp32,
RmsNorm,
RmsNorm2d,
RmsNormFp32,
RmsNorm2dFp32,
SimpleNorm,
SimpleNorm2d,
SimpleNormFp32,
SimpleNorm2dFp32,
)
from torchvision.ops.misc import FrozenBatchNorm2d

_NORM_MAP = dict(
@@ -21,10 +36,16 @@
groupnorm1=GroupNorm1,
layernorm=LayerNorm,
layernorm2d=LayerNorm2d,
layernormfp32=LayerNormFp32,
layernorm2dfp32=LayerNorm2dFp32,
rmsnorm=RmsNorm,
rmsnorm2d=RmsNorm2d,
rmsnormfp32=RmsNormFp32,
rmsnorm2dfp32=RmsNorm2dFp32,
simplenorm=SimpleNorm,
simplenorm2d=SimpleNorm2d,
simplenormfp32=SimpleNormFp32,
simplenorm2dfp32=SimpleNorm2dFp32,
frozenbatchnorm2d=FrozenBatchNorm2d,
)
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
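
A hedged sketch of resolving the new fp32 registry keys. It assumes the factory helpers in `create_norm.py` (`get_norm_layer` / `create_norm_layer`, not part of this hunk) look string names up against `_NORM_MAP`.

```python
from timm.layers import create_norm_layer, get_norm_layer

# Assumption: lowercase string keys resolve against _NORM_MAP, so the
# newly registered fp32 names are addressable by string.
norm_cls = get_norm_layer('rmsnorm2dfp32')       # -> RmsNorm2dFp32 (possibly wrapped in a partial)
norm = create_norm_layer('layernorm2dfp32', 64)  # assumed (layer_name, num_features) signature
print(norm_cls, norm)
```
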
74 changes: 56 additions & 18 deletions timm/layers/create_norm_act.py
@@ -8,19 +8,35 @@
"""
import types
import functools
from typing import Optional

from .evo_norm import *
from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, RmsNormAct, RmsNormAct2d
from .norm_act import (
BatchNormAct2d,
GroupNormAct,
GroupNorm1Act,
LayerNormAct,
LayerNormActFp32,
LayerNormAct2d,
LayerNormAct2dFp32,
RmsNormAct,
RmsNormActFp32,
RmsNormAct2d,
RmsNormAct2dFp32,
)
from .inplace_abn import InplaceAbn
from .typing import LayerType

_NORM_ACT_MAP = dict(
batchnorm=BatchNormAct2d,
batchnorm2d=BatchNormAct2d,
groupnorm=GroupNormAct,
groupnorm1=functools.partial(GroupNormAct, num_groups=1),
groupnorm1=GroupNorm1Act,
layernorm=LayerNormAct,
layernorm2d=LayerNormAct2d,
layernormfp32=LayerNormActFp32,
layernorm2dfp32=LayerNormAct2dFp32,
evonormb0=EvoNorm2dB0,
evonormb1=EvoNorm2dB1,
evonormb2=EvoNorm2dB2,
@@ -36,30 +52,62 @@
iabn=InplaceAbn,
rmsnorm=RmsNormAct,
rmsnorm2d=RmsNormAct2d,
rmsnormfp32=RmsNormActFp32,
rmsnorm2dfp32=RmsNormAct2dFp32,
)
_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
# Reverse map from base norm layer names to norm+act layer classes
_NORM_TO_NORM_ACT_MAP = dict(
batchnorm=BatchNormAct2d,
batchnorm2d=BatchNormAct2d,
groupnorm=GroupNormAct,
groupnorm1=GroupNorm1Act,
layernorm=LayerNormAct,
layernorm2d=LayerNormAct2d,
layernormfp32=LayerNormActFp32,
layernorm2dfp32=LayerNormAct2dFp32,
rmsnorm=RmsNormAct,
rmsnorm2d=RmsNormAct2d,
rmsnormfp32=RmsNormActFp32,
rmsnorm2dfp32=RmsNormAct2dFp32,
)
# has act_layer arg to define act type
_NORM_ACT_REQUIRES_ARG = {
BatchNormAct2d,
GroupNormAct,
GroupNorm1Act,
LayerNormAct,
LayerNormAct2d,
LayerNormActFp32,
LayerNormAct2dFp32,
FilterResponseNormAct2d,
InplaceAbn,
RmsNormAct,
RmsNormAct2d,
RmsNormActFp32,
RmsNormAct2dFp32,
}


def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
def create_norm_act_layer(
layer_name: LayerType,
num_features: int,
act_layer: Optional[LayerType] = None,
apply_act: bool = True,
jit: bool = False,
**kwargs,
):
layer = get_norm_act_layer(layer_name, act_layer=act_layer)
layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
if jit:
layer_instance = torch.jit.script(layer_instance)
return layer_instance


def get_norm_act_layer(norm_layer, act_layer=None):
def get_norm_act_layer(
norm_layer: LayerType,
act_layer: Optional[LayerType] = None,
):
if norm_layer is None:
return None
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
@@ -82,26 +130,16 @@ def get_norm_act_layer(norm_layer, act_layer=None):
# if function type, must be a lambda/fn that creates a norm_act layer
norm_act_layer = norm_layer
else:
# Use reverse map to find the corresponding norm+act layer
type_name = norm_layer.__name__.lower()
if type_name.startswith('batchnorm'):
norm_act_layer = BatchNormAct2d
elif type_name.startswith('groupnorm'):
norm_act_layer = GroupNormAct
elif type_name.startswith('groupnorm1'):
norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
elif type_name.startswith('layernorm2d'):
norm_act_layer = LayerNormAct2d
elif type_name.startswith('layernorm'):
norm_act_layer = LayerNormAct
elif type_name.startswith('rmsnorm2d'):
norm_act_layer = RmsNormAct2d
else:
assert False, f"No equivalent norm_act layer for {type_name}"
norm_act_layer = _NORM_TO_NORM_ACT_MAP.get(type_name, None)
assert norm_act_layer is not None, f"No equivalent norm_act layer for {type_name}"

if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
# pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
# In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
norm_act_kwargs.setdefault('act_layer', act_layer)
if norm_act_kwargs:
norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args

return norm_act_layer
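
To illustrate the new `_NORM_TO_NORM_ACT_MAP` reverse lookup and the widened signatures, a hedged sketch; argument names and constructor shapes are assumed from the surrounding code rather than taken from this diff.

```python
import torch.nn as nn
from timm.layers import get_norm_act_layer, create_norm_act_layer, LayerNorm2dFp32

# String lookup goes straight through _NORM_ACT_MAP.
na_from_name = get_norm_act_layer('layernorm2dfp32', act_layer=nn.GELU)

# Passing a plain norm type exercises the reverse map: the class name is
# lowercased ('layernorm2dfp32') and mapped to LayerNormAct2dFp32.
na_from_type = get_norm_act_layer(LayerNorm2dFp32, act_layer=nn.GELU)

# Assumption: the norm+act classes take num_features as the first argument.
layer = create_norm_act_layer('rmsnorm2dfp32', 64, act_layer=nn.SiLU)
print(na_from_name, na_from_type, layer)
```
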