1 change: 1 addition & 0 deletions README.md
@@ -508,6 +508,7 @@ All model architecture families include variants with pretrained weights. There
* Res2Net - https://arxiv.org/abs/1904.01169
* ResNeSt - https://arxiv.org/abs/2004.08955
* ReXNet - https://arxiv.org/abs/2007.00992
* ROPE-ViT - https://arxiv.org/abs/2403.13298
* SelecSLS - https://arxiv.org/abs/1907.00837
* Selective Kernel Networks - https://arxiv.org/abs/1903.06586
* Sequencer2D - https://arxiv.org/abs/2205.01972
102 changes: 86 additions & 16 deletions timm/layers/__init__.py
@@ -1,16 +1,40 @@
from ._fx import (
create_feature_extractor,
get_graph_node_names,
register_notrace_function,
register_notrace_module,
is_notrace_module,
is_notrace_function,
get_notrace_modules,
get_notrace_functions,
)
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .adaptive_avgmax_pool import (
adaptive_avgmax_pool2d,
select_adaptive_pool2d,
AdaptiveAvgMaxPool2d,
SelectAdaptivePool2d,
)
from .attention import Attention, AttentionRope, maybe_add_mask
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
from .attention_pool import AttentionPoolLatent
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d, create_aa
from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
set_reentrant_ckpt, use_reentrant_ckpt
from .config import (
is_exportable,
is_scriptable,
is_no_jit,
use_fused_attn,
set_exportable,
set_scriptable,
set_no_jit,
set_layer_config,
set_fused_attn,
set_reentrant_ckpt,
use_reentrant_ckpt,
)
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
@@ -20,8 +44,17 @@
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .evo_norm import (
EvoNorm2dB0,
EvoNorm2dB1,
EvoNorm2dB2,
EvoNorm2dS0,
EvoNorm2dS0a,
EvoNorm2dS1,
EvoNorm2dS1a,
EvoNorm2dS2,
EvoNorm2dS2a,
)
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
@@ -37,19 +70,50 @@
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .norm_act import (
BatchNormAct2d,
GroupNormAct,
GroupNorm1Act,
LayerNormAct,
LayerNormAct2d,
SyncBatchNormAct,
convert_sync_batchnorm,
FrozenBatchNormAct2d,
freeze_batch_norm_2d,
unfreeze_batch_norm_2d,
)
from .padding import get_padding, get_same_padding, pad_same
from .patch_dropout import PatchDropout
from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
from .pool1d import global_pool_nlc
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
from .pos_embed_rel import (
RelPosMlp,
RelPosBias,
RelPosBiasTf,
gen_relative_position_index,
gen_relative_log_coords,
resize_rel_pos_bias_table,
resize_rel_pos_bias_table_simple,
resize_rel_pos_bias_table_levit,
)
from .pos_embed_sincos import (
pixel_freq_bands,
freq_bands,
build_sincos2d_pos_embed,
build_fourier_pos_embed,
build_rotary_pos_embed,
apply_rot_embed,
apply_rot_embed_cat,
apply_rot_embed_list,
apply_keep_indices_nlc,
FourierEmbed,
RotaryEmbedding,
RotaryEmbeddingCat,
RotaryEmbeddingMixed,
get_mixed_freqs,
)
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
@@ -60,5 +124,11 @@
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .typing import LayerType, PadType
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \
init_weight_jax, init_weight_vit
from .weight_init import (
trunc_normal_,
trunc_normal_tf_,
variance_scaling_,
lecun_normal_,
init_weight_jax,
init_weight_vit,
)
81 changes: 81 additions & 0 deletions timm/layers/_fx.py
@@ -0,0 +1,81 @@
from typing import Callable, Dict, List, Optional, Union, Tuple, Type

import torch
from torch import nn

try:
# NOTE we wrap torchvision fns to use timm leaf / no trace definitions
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False


__all__ = [
'register_notrace_module',
'is_notrace_module',
'get_notrace_modules',
'register_notrace_function',
'is_notrace_function',
'get_notrace_functions',
'create_feature_extractor',
'get_graph_node_names',
]

# modules to treat as leaves when tracing
_leaf_modules = set()


def register_notrace_module(module: Type[nn.Module]):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
"""
_leaf_modules.add(module)
return module


def is_notrace_module(module: Type[nn.Module]):
return module in _leaf_modules


def get_notrace_modules():
return list(_leaf_modules)


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(name_or_fn):
    """
    Decorator for functions which ought not to be traced through (they are autowrapped as graph leaves).
    """
    _autowrap_functions.add(name_or_fn)
    return name_or_fn


def is_notrace_function(func: Callable):
return func in _autowrap_functions


def get_notrace_functions():
return list(_autowrap_functions)


def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
return _get_graph_node_names(
model,
tracer_kwargs={
'leaf_modules': list(_leaf_modules),
'autowrap_functions': list(_autowrap_functions)
}
)


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
return _create_feature_extractor(
model, return_nodes,
tracer_kwargs={
'leaf_modules': list(_leaf_modules),
'autowrap_functions': list(_autowrap_functions)
}
)
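For context, the registry above is consumed by the two wrappers at the bottom of the file. Below is a minimal sketch of the intended usage, not part of the diff: `DataDependentPool` and `ToyNet` are hypothetical, and it assumes torchvision 0.11+ plus the `timm.layers` re-exports added in this PR.

```python
import torch
from torch import nn
from torch.nn import functional as F

from timm.layers import register_notrace_module, get_graph_node_names, create_feature_extractor


@register_notrace_module  # FX treats this module as a leaf and does not trace its forward()
class DataDependentPool(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Python control flow on tensor shapes would normally break symbolic tracing.
        if x.shape[-1] > 7:
            x = F.adaptive_avg_pool2d(x, 7)
        return x


class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, 3, stride=2, padding=1)
        self.pool = DataDependentPool()

    def forward(self, x):
        return self.pool(self.stem(x))


model = ToyNet()
train_nodes, eval_nodes = get_graph_node_names(model)    # leaf modules appear as single nodes
extractor = create_feature_extractor(model, return_nodes={'stem': 'feat'})
out = extractor(torch.randn(1, 3, 224, 224))
print(out['feat'].shape)  # expected: torch.Size([1, 16, 112, 112])
```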
3 changes: 3 additions & 0 deletions timm/layers/attention.py
@@ -4,10 +4,13 @@
from torch import nn as nn
from torch.nn import functional as F

from ._fx import register_notrace_function
from .config import use_fused_attn
from .pos_embed_sincos import apply_rot_embed_cat


@torch.fx.wrap
@register_notrace_function
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
return scores if attn_mask is None else scores + attn_mask

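The reason `maybe_add_mask` carries both wrappers is its `Optional` argument: a plain symbolic trace would specialize on whether `attn_mask` is `None` at trace time, while a leaf call keeps both paths. A quick sketch of the behaviour, not part of the diff:

```python
import torch
from timm.layers import maybe_add_mask

scores = torch.randn(1, 4, 16, 16)  # (batch, heads, queries, keys) attention logits
causal = torch.triu(torch.ones(16, 16, dtype=torch.bool), diagonal=1)
mask = torch.zeros(16, 16).masked_fill(causal, float('-inf'))

unmasked = maybe_add_mask(scores)        # attn_mask is None -> scores returned unchanged
masked = maybe_add_mask(scores, mask)    # additive mask broadcasts over batch and heads
```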
2 changes: 1 addition & 1 deletion timm/layers/attention_pool2d.py
@@ -12,7 +12,7 @@
import torch
import torch.nn as nn

from. config import use_fused_attn
from .config import use_fused_attn
from .helpers import to_2tuple
from .pos_embed import resample_abs_pos_embed
from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
2 changes: 2 additions & 0 deletions timm/layers/cond_conv2d.py
@@ -12,6 +12,7 @@
from torch import nn as nn
from torch.nn import functional as F

from ._fx import register_notrace_module
from .helpers import to_2tuple
from .conv2d_same import conv2d_same
from .padding import get_padding_value
@@ -30,6 +31,7 @@ def condconv_initializer(weight):
return condconv_initializer


@register_notrace_module
class CondConv2d(nn.Module):
""" Conditionally Parameterized Convolution
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
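`CondConv2d` is registered as a leaf because its forward builds per-sample kernels from the routing weights, which FX cannot trace cleanly. A rough usage sketch, assuming the constructor and forward signatures in the current timm source (not part of the diff):

```python
import torch
from timm.layers import CondConv2d

conv = CondConv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, num_experts=4)
x = torch.randn(2, 16, 8, 8)
routing = torch.softmax(torch.randn(2, 4), dim=-1)  # per-sample mixing weights over the 4 experts
y = conv(x, routing)
print(y.shape)  # expected: torch.Size([2, 32, 8, 8])
```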
2 changes: 2 additions & 0 deletions timm/layers/conv2d_same.py
@@ -7,6 +7,7 @@
import torch.nn.functional as F
from typing import Tuple, Optional

from ._fx import register_notrace_module
from .config import is_exportable, is_scriptable
from .padding import pad_same, pad_same_arg, get_padding_value

@@ -27,6 +28,7 @@ def conv2d_same(
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)


@register_notrace_module
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
3 changes: 3 additions & 0 deletions timm/layers/inplace_abn.py
@@ -15,7 +15,10 @@ def inplace_abn(x, weight, bias, running_mean, running_var,
def inplace_abn_sync(**kwargs):
inplace_abn(**kwargs)

from ._fx import register_notrace_module


@register_notrace_module
class InplaceAbn(nn.Module):
"""Activated Batch Normalization

2 changes: 2 additions & 0 deletions timm/layers/non_local_attn.py
@@ -8,6 +8,7 @@
from torch import nn
from torch.nn import functional as F

from ._fx import register_notrace_module
from .conv_bn_act import ConvNormAct
from .helpers import make_divisible
from .trace_utils import _assert
@@ -69,6 +70,7 @@ def reset_parameters(self):
nn.init.constant_(m.bias, 0)


@register_notrace_module
class BilinearAttnTransform(nn.Module):

def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
4 changes: 4 additions & 0 deletions timm/layers/norm_act.py
@@ -19,6 +19,7 @@
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d

from ._fx import register_notrace_module
from .create_act import create_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d
from .norm import RmsNorm, RmsNorm2d
@@ -39,6 +40,7 @@ def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
return nn.Identity() if act is None else act


@register_notrace_module
class BatchNormAct2d(nn.BatchNorm2d):
"""BatchNorm + Activation

@@ -134,6 +136,7 @@ def forward(self, x):
return x


@register_notrace_module
class SyncBatchNormAct(nn.SyncBatchNorm):
# Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
# This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
@@ -191,6 +194,7 @@ def convert_sync_batchnorm(module, process_group=None):
return module_output


@register_notrace_module
class FrozenBatchNormAct2d(torch.nn.Module):
"""
BatchNormAct2d where the batch statistics and the affine parameters are fixed
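The norm+act classes above are registered as leaves so symbolic tracing does not trip over the control flow in their fused forward paths. For `convert_sync_batchnorm`, here is a minimal sketch of why the timm helper exists instead of `torch.nn.SyncBatchNorm.convert_sync_batchnorm`; it is not part of the diff and assumes the current timm behaviour of mapping `BatchNormAct2d` to `SyncBatchNormAct`:

```python
import torch.nn as nn
from timm.layers import BatchNormAct2d, convert_sync_batchnorm

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    BatchNormAct2d(16, act_layer=nn.ReLU),   # norm + activation fused into one module
)
sync_model = convert_sync_batchnorm(model)   # keeps the fused activation intact
print(type(sync_model[1]).__name__)          # expected: SyncBatchNormAct
```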
22 changes: 18 additions & 4 deletions timm/layers/pool2d_same.py
@@ -7,17 +7,25 @@
import torch.nn.functional as F
from typing import List, Tuple, Optional

from ._fx import register_notrace_module
from .helpers import to_2tuple
from .padding import pad_same, get_padding_value


def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
ceil_mode: bool = False, count_include_pad: bool = True):
def avg_pool2d_same(
x: torch.Tensor,
kernel_size: List[int],
stride: List[int],
padding: List[int] = (0, 0),
ceil_mode: bool = False,
count_include_pad: bool = True,
):
# FIXME how to deal with count_include_pad vs not for external padding?
x = pad_same(x, kernel_size, stride)
return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)


@register_notrace_module
class AvgPool2dSame(nn.AvgPool2d):
""" Tensorflow like 'SAME' wrapper for 2D average pooling
"""
@@ -33,12 +41,18 @@ def forward(self, x):


def max_pool2d_same(
x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
dilation: List[int] = (1, 1), ceil_mode: bool = False):
x: torch.Tensor,
kernel_size: List[int],
stride: List[int],
padding: List[int] = (0, 0),
dilation: List[int] = (1, 1),
ceil_mode: bool = False,
):
x = pad_same(x, kernel_size, stride, value=-float('inf'))
return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)


@register_notrace_module
class MaxPool2dSame(nn.MaxPool2d):
""" Tensorflow like 'SAME' wrapper for 2D max pooling
"""