10 changes: 5 additions & 5 deletions tests/test_models.py
@@ -75,13 +75,13 @@
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
         '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', '*huge*', '*giant*', '*gigantic*',
-        '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*']
-    NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*']
+        '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*', '*_7b_*']
+    NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*', '*_7b_*']
 else:
-    EXCLUDE_FILTERS = ['*enormous*']
-    NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*']
+    EXCLUDE_FILTERS = ['*enormous*', '*_7b_*']
+    NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*', '*_7b_*']
 
-EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*']
+EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*', '*_7b_*']
 
 TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
 TARGET_BWD_SIZE = 128
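For context, these exclusion lists are shell-style wildcard patterns matched against model names (the test suite filters models with fnmatch-style matching), so the new '*_7b_*' entries keep 7B-parameter variants out of the memory-limited CI runs and the JIT tests. A minimal sketch of the pattern semantics, using hypothetical model names:

```python
import fnmatch

# Hypothetical model names, only to illustrate how the '*_7b_*' pattern behaves.
names = ['vit_base_patch16_224', 'vit_7b_patch16_example']

excluded = [n for n in names if fnmatch.fnmatch(n, '*_7b_*')]
print(excluded)  # ['vit_7b_patch16_example']
```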
2 changes: 2 additions & 0 deletions timm/layers/__init__.py
@@ -133,7 +133,9 @@
     RotaryEmbedding,
     RotaryEmbeddingCat,
     RotaryEmbeddingMixed,
+    RotaryEmbeddingDinoV3,
     get_mixed_freqs,
+    create_rope_embed,
 )
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
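The two new names are re-exported from the package-level layers namespace alongside the existing rotary embedding classes. A quick sanity check (assumes a timm checkout that includes this change; the comments describe only what the names suggest):

```python
# Assumes a timm install that includes this PR's changes.
from timm.layers import RotaryEmbeddingDinoV3, create_rope_embed

print(RotaryEmbeddingDinoV3)  # rotary embedding variant added by this PR
print(create_rope_embed)      # new helper export; the name suggests a RoPE factory function
```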
8 changes: 6 additions & 2 deletions timm/layers/attention.py
@@ -121,6 +121,7 @@ def __init__(
             qk_norm: bool = False,
             scale_norm: bool = False,
             proj_bias: bool = True,
+            rotate_half: bool = False,
     ):
         """Initialize the Attention module.
 
@@ -136,6 +137,7 @@ def __init__(
             norm_layer: Normalization layer constructor to use for QK and scale normalization
             qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
             scale_norm: Enable normalization (scaling) of attention output with norm_layer
+            rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
         """
         super().__init__()
         if scale_norm or qk_norm:
@@ -148,6 +150,7 @@
         self.scale = head_dim ** -0.5
         self.num_prefix_tokens = num_prefix_tokens
         self.fused_attn = use_fused_attn()
+        self.rotate_half = rotate_half
 
         if qkv_fused:
             self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
@@ -196,8 +199,9 @@ def forward(
 
         if rope is not None:
             npt = self.num_prefix_tokens
-            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
-            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)
+            half = getattr(self, 'rotate_half', False)
+            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
+            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
 
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
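The new rotate_half flag selects between the two common RoPE dimension pairings: the default 'interleaved' layout rotates adjacent pairs (x0, x1), (x2, x3), ..., while the 'half' layout (GPT-NeoX/LLaMA style, presumably what the DINOv3-style rotary embedding added in this PR expects) pairs dimension i with dimension i + D/2. A generic sketch of the difference in plain PyTorch; this is not timm's apply_rot_embed_cat signature, and the sin/cos tensors are assumed to already be broadcast to the last dimension of x:

```python
import torch

def rotate_interleaved(x: torch.Tensor) -> torch.Tensor:
    # 'interleaved' layout: rotate adjacent pairs (x0, x1), (x2, x3), ...
    x = x.reshape(*x.shape[:-1], -1, 2)
    x = torch.stack([-x[..., 1], x[..., 0]], dim=-1)
    return x.reshape(*x.shape[:-2], -1)

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # 'half' layout: pair dimension i with dimension i + D/2 (rotate the two halves)
    a, b = x.chunk(2, dim=-1)
    return torch.cat([-b, a], dim=-1)

def apply_rope(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, half: bool = False) -> torch.Tensor:
    # Same rotation formula either way; only the pairing of dimensions differs.
    rotated = rotate_half(x) if half else rotate_interleaved(x)
    return x * cos + rotated * sin
```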
4 changes: 4 additions & 0 deletions timm/layers/mlp.py
@@ -119,13 +119,17 @@ def __init__(
             norm_layer=None,
             bias=True,
             drop=0.,
+            align_to=0,
     ):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         bias = to_2tuple(bias)
         drop_probs = to_2tuple(drop)
 
+        if align_to:
+            hidden_features = hidden_features + (-hidden_features % align_to)
+
         self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
         self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
         self.act = act_layer()
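The align_to option pads the hidden width up to the next multiple of the given value (the expression value + (-value % multiple) is a standard ceil-to-multiple), presumably to keep the projection shapes aligned to sizes that matmul kernels handle efficiently. A quick check of the arithmetic:

```python
def round_up_to_multiple(value: int, multiple: int) -> int:
    # Same expression as in the diff: pad value up to the next multiple.
    return value + (-value % multiple)

assert round_up_to_multiple(4096, 64) == 4096   # already aligned, unchanged
assert round_up_to_multiple(4100, 64) == 4160   # padded up to the next multiple of 64
assert round_up_to_multiple(1, 256) == 256
```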