diff --git a/README.md b/README.md index 85394bd66e..c453adfcf3 100644 --- a/README.md +++ b/README.md @@ -508,6 +508,7 @@ All model architecture families include variants with pretrained weights. There * Res2Net - https://arxiv.org/abs/1904.01169 * ResNeSt - https://arxiv.org/abs/2004.08955 * ReXNet - https://arxiv.org/abs/2007.00992 +* ROPE-ViT - https://arxiv.org/abs/2403.13298 * SelecSLS - https://arxiv.org/abs/1907.00837 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586 * Sequencer2D - https://arxiv.org/abs/2205.01972 diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index ac24140ea5..5c474fefcb 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -1,6 +1,20 @@ +from ._fx import ( + create_feature_extractor, + get_graph_node_names, + register_notrace_function, + register_notrace_module, + is_notrace_module, + is_notrace_function, + get_notrace_modules, + get_notrace_functions, +) from .activations import * -from .adaptive_avgmax_pool import \ - adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .adaptive_avgmax_pool import ( + adaptive_avgmax_pool2d, + select_adaptive_pool2d, + AdaptiveAvgMaxPool2d, + SelectAdaptivePool2d, +) from .attention import Attention, AttentionRope, maybe_add_mask from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2 from .attention_pool import AttentionPoolLatent @@ -8,9 +22,19 @@ from .blur_pool import BlurPool2d, create_aa from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead from .cond_conv2d import CondConv2d, get_condconv_initializer -from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \ - set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \ - set_reentrant_ckpt, use_reentrant_ckpt +from .config import ( + is_exportable, + is_scriptable, + is_no_jit, + use_fused_attn, + set_exportable, + set_scriptable, + set_no_jit, + set_layer_config, + set_fused_attn, + set_reentrant_ckpt, + use_reentrant_ckpt, +) from .conv2d_same import Conv2dSame, conv2d_same from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct from .create_act import create_act_layer, get_act_layer, get_act_fn @@ -20,8 +44,17 @@ from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn -from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ - EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .evo_norm import ( + EvoNorm2dB0, + EvoNorm2dB1, + EvoNorm2dB2, + EvoNorm2dS0, + EvoNorm2dS0a, + EvoNorm2dS1, + EvoNorm2dS1a, + EvoNorm2dS2, + EvoNorm2dS2a, +) from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to @@ -37,19 +70,50 @@ from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d -from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ - SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, 
freeze_batch_norm_2d, unfreeze_batch_norm_2d +from .norm_act import ( + BatchNormAct2d, + GroupNormAct, + GroupNorm1Act, + LayerNormAct, + LayerNormAct2d, + SyncBatchNormAct, + convert_sync_batchnorm, + FrozenBatchNormAct2d, + freeze_batch_norm_2d, + unfreeze_batch_norm_2d, +) from .padding import get_padding, get_same_padding, pad_same from .patch_dropout import PatchDropout from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed from .pool1d import global_pool_nlc from .pool2d_same import AvgPool2dSame, create_pool2d from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc -from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \ - resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit -from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \ - build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \ - FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat +from .pos_embed_rel import ( + RelPosMlp, + RelPosBias, + RelPosBiasTf, + gen_relative_position_index, + gen_relative_log_coords, + resize_rel_pos_bias_table, + resize_rel_pos_bias_table_simple, + resize_rel_pos_bias_table_levit, +) +from .pos_embed_sincos import ( + pixel_freq_bands, + freq_bands, + build_sincos2d_pos_embed, + build_fourier_pos_embed, + build_rotary_pos_embed, + apply_rot_embed, + apply_rot_embed_cat, + apply_rot_embed_list, + apply_keep_indices_nlc, + FourierEmbed, + RotaryEmbedding, + RotaryEmbeddingCat, + RotaryEmbeddingMixed, + get_mixed_freqs, +) from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite from .selective_kernel import SelectiveKernel from .separable_conv import SeparableConv2d, SeparableConvNormAct @@ -60,5 +124,11 @@ from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .trace_utils import _assert, _float_to_int from .typing import LayerType, PadType -from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \ - init_weight_jax, init_weight_vit +from .weight_init import ( + trunc_normal_, + trunc_normal_tf_, + variance_scaling_, + lecun_normal_, + init_weight_jax, + init_weight_vit, +) diff --git a/timm/layers/_fx.py b/timm/layers/_fx.py new file mode 100644 index 0000000000..ae5b701d68 --- /dev/null +++ b/timm/layers/_fx.py @@ -0,0 +1,81 @@ +from typing import Callable, Dict, List, Optional, Union, Tuple, Type + +import torch +from torch import nn + +try: + # NOTE we wrap torchvision fns to use timm leaf / no trace definitions + from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor + from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names + has_fx_feature_extraction = True +except ImportError: + has_fx_feature_extraction = False + + +__all__ = [ + 'register_notrace_module', + 'is_notrace_module', + 'get_notrace_modules', + 'register_notrace_function', + 'is_notrace_function', + 'get_notrace_functions', + 'create_feature_extractor', + 'get_graph_node_names', +] + +# modules to treat as leafs when tracing +_leaf_modules = set() + + +def register_notrace_module(module: Type[nn.Module]): + """ + Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
+ """ + _leaf_modules.add(module) + return module + + +def is_notrace_module(module: Type[nn.Module]): + return module in _leaf_modules + + +def get_notrace_modules(): + return list(_leaf_modules) + + +# Functions we want to autowrap (treat them as leaves) +_autowrap_functions = set() + + +def register_notrace_function(name_or_fn): + _autowrap_functions.add(name_or_fn) + return name_or_fn + + +def is_notrace_function(func: Callable): + return func in _autowrap_functions + + +def get_notrace_functions(): + return list(_autowrap_functions) + + +def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]: + return _get_graph_node_names( + model, + tracer_kwargs={ + 'leaf_modules': list(_leaf_modules), + 'autowrap_functions': list(_autowrap_functions) + } + ) + + +def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' + return _create_feature_extractor( + model, return_nodes, + tracer_kwargs={ + 'leaf_modules': list(_leaf_modules), + 'autowrap_functions': list(_autowrap_functions) + } + ) \ No newline at end of file diff --git a/timm/layers/attention.py b/timm/layers/attention.py index 9144f3faae..2fb394e844 100644 --- a/timm/layers/attention.py +++ b/timm/layers/attention.py @@ -4,10 +4,13 @@ from torch import nn as nn from torch.nn import functional as F +from ._fx import register_notrace_function from .config import use_fused_attn from .pos_embed_sincos import apply_rot_embed_cat +@torch.fx.wrap +@register_notrace_function def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): return scores if attn_mask is None else scores + attn_mask diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py index e6c7041760..7fc3b962eb 100644 --- a/timm/layers/attention_pool2d.py +++ b/timm/layers/attention_pool2d.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -from. 
config import use_fused_attn +from .config import use_fused_attn from .helpers import to_2tuple from .pos_embed import resample_abs_pos_embed from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding diff --git a/timm/layers/cond_conv2d.py b/timm/layers/cond_conv2d.py index 32f02c98a0..bdeb8666f0 100644 --- a/timm/layers/cond_conv2d.py +++ b/timm/layers/cond_conv2d.py @@ -12,6 +12,7 @@ from torch import nn as nn from torch.nn import functional as F +from ._fx import register_notrace_module from .helpers import to_2tuple from .conv2d_same import conv2d_same from .padding import get_padding_value @@ -30,6 +31,7 @@ def condconv_initializer(weight): return condconv_initializer +@register_notrace_module class CondConv2d(nn.Module): """ Conditionally Parameterized Convolution Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py diff --git a/timm/layers/conv2d_same.py b/timm/layers/conv2d_same.py index 7ac85b7938..f3d18495cc 100644 --- a/timm/layers/conv2d_same.py +++ b/timm/layers/conv2d_same.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from typing import Tuple, Optional +from ._fx import register_notrace_module from .config import is_exportable, is_scriptable from .padding import pad_same, pad_same_arg, get_padding_value @@ -27,6 +28,7 @@ def conv2d_same( return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) +@register_notrace_module class Conv2dSame(nn.Conv2d): """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions """ diff --git a/timm/layers/inplace_abn.py b/timm/layers/inplace_abn.py index a80889339e..74fefef88a 100644 --- a/timm/layers/inplace_abn.py +++ b/timm/layers/inplace_abn.py @@ -15,7 +15,10 @@ def inplace_abn(x, weight, bias, running_mean, running_var, def inplace_abn_sync(**kwargs): inplace_abn(**kwargs) +from ._fx import register_notrace_module + +@register_notrace_module class InplaceAbn(nn.Module): """Activated Batch Normalization diff --git a/timm/layers/non_local_attn.py b/timm/layers/non_local_attn.py index 670e8f2475..71fe208290 100644 --- a/timm/layers/non_local_attn.py +++ b/timm/layers/non_local_attn.py @@ -8,6 +8,7 @@ from torch import nn from torch.nn import functional as F +from ._fx import register_notrace_module from .conv_bn_act import ConvNormAct from .helpers import make_divisible from .trace_utils import _assert @@ -69,6 +70,7 @@ def reset_parameters(self): nn.init.constant_(m.bias, 0) +@register_notrace_module class BilinearAttnTransform(nn.Module): def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index dd9413105e..42eb2da609 100644 --- a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -19,6 +19,7 @@ from torch.nn import functional as F from torchvision.ops.misc import FrozenBatchNorm2d +from ._fx import register_notrace_module from .create_act import create_act_layer from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d from .norm import RmsNorm, RmsNorm2d @@ -39,6 +40,7 @@ def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True): return nn.Identity() if act is None else act +@register_notrace_module class BatchNormAct2d(nn.BatchNorm2d): """BatchNorm + Activation @@ -134,6 +136,7 @@ def forward(self, x): return x +@register_notrace_module class SyncBatchNormAct(nn.SyncBatchNorm): # Thanks to Selim Seferbekov 
(https://github.com/rwightman/pytorch-image-models/issues/1254) # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers @@ -191,6 +194,7 @@ def convert_sync_batchnorm(module, process_group=None): return module_output +@register_notrace_module class FrozenBatchNormAct2d(torch.nn.Module): """ BatchNormAct2d where the batch statistics and the affine parameters are fixed diff --git a/timm/layers/pool2d_same.py b/timm/layers/pool2d_same.py index 4c2a1c4471..9e9f3046ba 100644 --- a/timm/layers/pool2d_same.py +++ b/timm/layers/pool2d_same.py @@ -7,17 +7,25 @@ import torch.nn.functional as F from typing import List, Tuple, Optional +from ._fx import register_notrace_module from .helpers import to_2tuple from .padding import pad_same, get_padding_value -def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), - ceil_mode: bool = False, count_include_pad: bool = True): +def avg_pool2d_same( + x: torch.Tensor, + kernel_size: List[int], + stride: List[int], + padding: List[int] = (0, 0), + ceil_mode: bool = False, + count_include_pad: bool = True, +): # FIXME how to deal with count_include_pad vs not for external padding? x = pad_same(x, kernel_size, stride) return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) +@register_notrace_module class AvgPool2dSame(nn.AvgPool2d): """ Tensorflow like 'SAME' wrapper for 2D average pooling """ @@ -33,12 +41,18 @@ def forward(self, x): def max_pool2d_same( - x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), - dilation: List[int] = (1, 1), ceil_mode: bool = False): + x: torch.Tensor, + kernel_size: List[int], + stride: List[int], + padding: List[int] = (0, 0), + dilation: List[int] = (1, 1), + ceil_mode: bool = False, +): x = pad_same(x, kernel_size, stride, value=-float('inf')) return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) +@register_notrace_module class MaxPool2dSame(nn.MaxPool2d): """ Tensorflow like 'SAME' wrapper for 2D max pooling """ diff --git a/timm/layers/pos_embed.py b/timm/layers/pos_embed.py index 0d50207dd8..df12a1aa7e 100644 --- a/timm/layers/pos_embed.py +++ b/timm/layers/pos_embed.py @@ -9,11 +9,13 @@ import torch import torch.nn.functional as F -from .helpers import to_2tuple +from ._fx import register_notrace_function _logger = logging.getLogger(__name__) +@torch.fx.wrap +@register_notrace_function def resample_abs_pos_embed( posemb: torch.Tensor, new_size: List[int], @@ -57,6 +59,8 @@ def resample_abs_pos_embed( return posemb +@torch.fx.wrap +@register_notrace_function def resample_abs_pos_embed_nhwc( posemb: torch.Tensor, new_size: List[int], diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index f8f169d5c2..d0978c5610 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -8,6 +8,7 @@ import torch from torch import nn as nn +from ._fx import register_notrace_function from .grid import ndgrid from .trace_utils import _assert @@ -219,8 +220,6 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): def apply_rot_embed_cat(x: torch.Tensor, emb): sin_emb, cos_emb = emb.tensor_split(2, -1) - if sin_emb.ndim == 3: - return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x) return x * cos_emb + rot(x) * sin_emb @@ -351,6 +350,7 @@ def __init__( ref_feat_shape=self.ref_feat_shape, grid_offset=self.grid_offset, grid_indexing=self.grid_indexing, + temperature=self.temperature, ) self.bands = None 
self.register_buffer( @@ -365,9 +365,8 @@ def __init__( ) def get_embed(self, shape: Optional[List[int]] = None): - if self.bands is not None: + if shape is not None and self.bands is not None: # rebuild embeddings every call, use if target shape changes - assert shape is not None return build_rotary_pos_embed( shape, self.bands, @@ -376,8 +375,10 @@ def get_embed(self, shape: Optional[List[int]] = None): grid_offset=self.grid_offset, grid_indexing=self.grid_indexing, ) - else: + elif self.pos_embed_sin is not None and self.pos_embed_cos is not None: return self.pos_embed_sin, self.pos_embed_cos + else: + assert False, "get_embed() requires pre-computed pos embeds or valid shape w/ pre-computed bands" def forward(self, x): # assuming channel-first tensor where spatial dim are >= 2 @@ -446,6 +447,7 @@ def __init__( ref_feat_shape=self.ref_feat_shape, grid_offset=self.grid_offset, grid_indexing=self.grid_indexing, + temperature=self.temperature, ) self.bands = None self.register_buffer( @@ -455,7 +457,7 @@ def __init__( ) def get_embed(self, shape: Optional[List[int]] = None): - if self.bands is not None and shape is not None: + if shape is not None and self.bands is not None: # rebuild embeddings every call, use if target shape changes embeds = build_rotary_pos_embed( shape, @@ -469,9 +471,172 @@ def get_embed(self, shape: Optional[List[int]] = None): elif self.pos_embed is not None: return self.pos_embed else: - assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands" + assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands" def forward(self, x): # assuming channel-first tensor where spatial dim are >= 2 pos_embed = self.get_embed(x.shape[2:]) return apply_rot_embed_cat(x, pos_embed) + + +def init_random_2d_freqs( + head_dim: int, + depth: int, + num_heads: int, + temperature: float = 10.0, + rotate: bool = True, + *, + device=None, + dtype=torch.float32, +) -> torch.Tensor: + """ Vectorised 2D ROPE frequencies with random rotation for mixed mode ROPE. 
+ Returns: + Tensor (2, depth, num_heads, head_dim//2) + """ + # base magnitudes, shape: (head_dim//4,) + mag = 1.0 / (temperature ** (torch.arange(0, head_dim, 4, device=device, dtype=dtype) / head_dim)) + + # (1,1,L) so it broadcasts over both depth and heads + mag = mag.unsqueeze(0).unsqueeze(0) # (1,1,L) + + # random (or zero) rotation per head *and* per block + if rotate: + angles = torch.rand(depth, num_heads, 1, device=device, dtype=dtype) * 2 * torch.pi + else: + angles = torch.zeros(depth, num_heads, 1, device=device, dtype=dtype) + + # build (depth, num_heads, 2·L) == head_dim//2 on the last axis + fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(angles + torch.pi / 2)], dim=-1) + fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(angles + torch.pi / 2)], dim=-1) + + # (2, depth, num_heads, head_dim//2) + return torch.stack([fx, fy], dim=0) + + +@torch.fx.wrap +@register_notrace_function +def get_mixed_grid( + height: int, + width: int, + grid_indexing: str = 'ij', + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + x_pos, y_pos = torch.meshgrid( + torch.arange(height, dtype=dtype, device=device), + torch.arange(width, dtype=dtype, device=device), + indexing=grid_indexing, + ) + t_x = x_pos.flatten() + t_y = y_pos.flatten() + return t_x, t_y + + +def get_mixed_freqs( + freqs: torch.Tensor, + t_x: torch.Tensor, + t_y: torch.Tensor, +) -> torch.Tensor: + """Compute mixed (learnable) frequencies.""" + # Create position indices + dtype = freqs.dtype + freqs = freqs.float() + freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2)) + freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2)) + combined = freqs_x + freqs_y # shape: (num_heads, N, dim//4) + sin_emb = torch.sin(combined).repeat_interleave(2, -1) # (N, dim//2) + cos_emb = torch.cos(combined).repeat_interleave(2, -1) # (N, dim//2) + rope_embeds = torch.cat([sin_emb, cos_emb], dim=-1) # (num_heads, H*W, head_dim) + return rope_embeds.to(dtype) + + +class RotaryEmbeddingMixed(nn.Module): + """Rotary position embedding with depth-dependent learnable frequencies. + + This implementation supports mixed (learnable) ROPE. In mixed mode, + each transformer block has its own set of learnable frequency parameters. + + Based on 'Rotary Position Embedding for Vision: https://arxiv.org/abs/2403.13298)' + Compatible with original at https://github.com/naver-ai/rope-vit + """ + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + temperature: float = 10.0, + feat_shape: Optional[List[int]] = None, + grid_indexing: str = 'xy', + ): + """Initialize rotary embeddings. 
+ + Args: + dim: Embedding dimension (should be divisible by 4) + depth: Number of transformer blocks + num_heads: Number of attention heads + temperature: Base for frequency computation + feat_shape: Spatial dimensions [H, W] if known in advance + grid_indexing: How to index grid positions ('xy' or 'ij') + """ + super().__init__() + self.dim = dim + self.depth = depth + self.num_heads = num_heads + self.temperature = temperature + self.feat_shape = feat_shape + self.grid_indexing = grid_indexing + + head_dim = dim // num_heads + assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}" + freqs = init_random_2d_freqs( + head_dim, + depth, + num_heads, + temperature=temperature, + rotate=True, + ) # (2, depth, num_heads, head_dim//2) + self.freqs = nn.Parameter(freqs) + if feat_shape is not None: + # cache pre-computed grid + t_x, t_y = get_mixed_grid( + feat_shape[0], + feat_shape[1], + grid_indexing=grid_indexing, + device=self.freqs.device + ) + self.register_buffer('t_x', t_x, persistent=False) + self.register_buffer('t_y', t_y, persistent=False) + else: + self.t_x = self.t_y = None + + def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor: + """Generate rotary embeddings for the given spatial shape. + + Args: + shape: Spatial dimensions [H, W] + + Returns: + Tensor of shape (depth, H*W, dim) containing concatenated sin/cos embeddings + """ + if shape is not None: + t_x, t_y = get_mixed_grid( + shape[0], + shape[1], + grid_indexing=self.grid_indexing, + device=self.freqs.device + ) + elif self.t_x is not None and self.t_y is not None: + t_x, t_y = self.t_x, self.t_y + else: + assert False, "get_embed() requires pre-computed t_x/t_y or valid shape" + + return get_mixed_freqs(self.freqs, t_x, t_y) + + def forward(self, x): + # assuming channel-first tensor where spatial dim are >= 2 + pos_embed = self.get_embed(x.shape[2:]) + return apply_rot_embed_cat(x, pos_embed) + + def no_weight_decay(self): + """Exclude frequency parameters from weight decay.""" + return {'freqs'} diff --git a/timm/layers/std_conv.py b/timm/layers/std_conv.py index d896ba5c2f..287c60d516 100644 --- a/timm/layers/std_conv.py +++ b/timm/layers/std_conv.py @@ -20,6 +20,7 @@ import torch.nn as nn import torch.nn.functional as F +from ._fx import register_notrace_module from .padding import get_padding, get_padding_value, pad_same @@ -47,6 +48,7 @@ def forward(self, x): return x +@register_notrace_module class StdConv2dSame(nn.Conv2d): """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model. 
@@ -102,6 +104,7 @@ def forward(self, x): return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) +@register_notrace_module class ScaledStdConv2dSame(nn.Conv2d): """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index 0352d79a5a..9ddc3890c8 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -6,112 +6,33 @@ import torch from torch import nn +from timm.layers import ( + create_feature_extractor, + get_graph_node_names, + register_notrace_module, + register_notrace_function, + is_notrace_module, + is_notrace_function, + get_notrace_functions, + get_notrace_modules, + Format, + ) from ._features import _get_feature_info, _get_return_layers -try: - # NOTE we wrap torchvision fns to use timm leaf / no trace definitions - from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor - from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names - has_fx_feature_extraction = True -except ImportError: - has_fx_feature_extraction = False -# Layers we went to treat as leaf modules -from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format -from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc, maybe_add_mask -from timm.layers.non_local_attn import BilinearAttnTransform -from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame -from timm.layers.norm_act import ( - BatchNormAct2d, - SyncBatchNormAct, - FrozenBatchNormAct2d, - GroupNormAct, - GroupNorm1Act, - LayerNormAct, - LayerNormAct2d -) -__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules', - 'register_notrace_function', 'is_notrace_function', 'get_notrace_functions', - 'create_feature_extractor', 'get_graph_node_names', 'FeatureGraphNet', 'GraphExtractNet'] - - -# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here -# BUT modules from timm.models should use the registration mechanism below -_leaf_modules = { - BilinearAttnTransform, # reason: flow control t <= 1 - # Reason: get_same_padding has a max which raises a control flow error - Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, - CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]), - BatchNormAct2d, - SyncBatchNormAct, - FrozenBatchNormAct2d, - GroupNormAct, - GroupNorm1Act, - LayerNormAct, - LayerNormAct2d, -} - -try: - from timm.layers import InplaceAbn - _leaf_modules.add(InplaceAbn) -except ImportError: - pass - - -def register_notrace_module(module: Type[nn.Module]): - """ - Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
- """ - _leaf_modules.add(module) - return module - - -def is_notrace_module(module: Type[nn.Module]): - return module in _leaf_modules - - -def get_notrace_modules(): - return list(_leaf_modules) - - -# Functions we want to autowrap (treat them as leaves) -_autowrap_functions = { - resample_abs_pos_embed, - resample_abs_pos_embed_nhwc, - maybe_add_mask, -} - - -def register_notrace_function(func: Callable): - """ - Decorator for functions which ought not to be traced through - """ - _autowrap_functions.add(func) - return func - - -def is_notrace_function(func: Callable): - return func in _autowrap_functions - - -def get_notrace_functions(): - return list(_autowrap_functions) - - -def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]: - return _get_graph_node_names( - model, - tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} - ) - - -def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): - assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' - return _create_feature_extractor( - model, return_nodes, - tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} - ) +__all__ = [ + 'register_notrace_module', + 'is_notrace_module', + 'get_notrace_modules', + 'register_notrace_function', + 'is_notrace_function', + 'get_notrace_functions', + 'create_feature_extractor', + 'get_graph_node_names', + 'FeatureGraphNet', + 'GraphExtractNet', +] class FeatureGraphNet(nn.Module): @@ -128,7 +49,6 @@ def __init__( return_dict: bool = False, ): super().__init__() - assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' self.feature_info = _get_feature_info(model, out_indices) if out_map is not None: assert len(out_map) == len(out_indices) diff --git a/timm/models/eva.py b/timm/models/eva.py index 4e500130e7..9a48835a55 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -26,10 +26,20 @@ year={2025} } +@inproceedings{heo2024rotary, + title={Rotary position embedding for vision transformer}, + author={Heo, Byeongho and Park, Song and Han, Dongyoon and Yun, Sangdoo}, + booktitle={European Conference on Computer Vision}, + pages={289--305}, + year={2024}, + organization={Springer} +} + This file contains a number of ViT variants the utilise ROPE position embeddings, SwiGLU and other additions: * EVA & EVA02 model implementations that evolved from BEiT, additional models in vision_transformer.py. 
* `timm` original SBB ViT w/ ROPE position embeddings * Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181) + * ROPE-ViT from Naver AI (https://arxiv.org/abs/2403.13298) Modifications by / Copyright 2023 Ross Wightman, original copyrights below """ @@ -45,8 +55,9 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \ - apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, \ - global_pool_nlc, to_2tuple, use_fused_attn, AttentionRope, AttentionPoolLatent + RotaryEmbeddingMixed, apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, \ + resample_patch_embed, resample_abs_pos_embed, global_pool_nlc, to_2tuple, use_fused_attn, AttentionRope, \ + AttentionPoolLatent from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -429,10 +440,13 @@ def __init__( init_values: Optional[float] = None, class_token: bool = True, num_reg_tokens: int = 0, + no_embed_class: bool = False, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, + rope_mixed_mode: bool = False, rope_grid_offset: float = 0., rope_grid_indexing: str = 'ij', + rope_temperature: float = 10000., use_post_norm: bool = False, use_pre_transformer_norm: bool = False, use_post_transformer_norm: Optional[bool] = None, @@ -472,10 +486,13 @@ def __init__( init_values: Initial layer-scale values class_token: Use class token num_reg_tokens: Number of additional learnable 'register' tokens to add to the sequence + no_embed_class: Don't include position embeddings for class (or reg) tokens use_abs_pos_emb: Use absolute (learned) positional embeddings use_rot_pos_emb: Use rotary position embeddings + rope_mixed_mode: Use mixed mode ROPE with per-layer learnable frequencies rope_grid_offset: Offset for rotary position embedding grid rope_grid_indexing: Indexing mode for rotary position embeddings ('ij' or 'xy') + rope_temperature: Temperature parameter for ROPE frequency computation use_post_norm: Use post-norm transformer block type use_pre_transformer_norm: Use normalization layer before transformer blocks use_post_transformer_norm: Use normalization layer after transformer blocks @@ -493,6 +510,7 @@ def __init__( self.global_pool = global_pool self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models self.num_prefix_tokens = (1 if class_token else 0) + num_reg_tokens + self.no_embed_class = no_embed_class self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False @@ -527,8 +545,8 @@ def __init__( self.reg_token = nn.Parameter(torch.zeros(1, num_reg_tokens, embed_dim)) if num_reg_tokens else None self.cls_embed = class_token and self.reg_token is None - self.pos_embed = nn.Parameter( - torch.zeros(1, num_patches + self.num_prefix_tokens, embed_dim)) if use_abs_pos_emb else None + num_pos_tokens = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.zeros(1, num_pos_tokens, embed_dim)) if use_abs_pos_emb else None self.pos_drop = nn.Dropout(p=pos_drop_rate) if patch_drop_rate > 0: self.patch_drop = PatchDropout( @@ -541,15 +559,30 @@ def __init__( if use_rot_pos_emb: ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None - self.rope = RotaryEmbeddingCat( - embed_dim // num_heads, - in_pixels=False, - feat_shape=None if 
dynamic_img_size else self.patch_embed.grid_size, - ref_feat_shape=ref_feat_shape, - grid_offset=rope_grid_offset, - grid_indexing=rope_grid_indexing, - ) + if rope_mixed_mode: + self.rope_mixed = True + # Mixed mode to supports depth-dependent frequencies + self.rope = RotaryEmbeddingMixed( + dim=embed_dim, + depth=depth, + num_heads=num_heads, + temperature=rope_temperature, + feat_shape=None if dynamic_img_size else self.patch_embed.grid_size, + grid_indexing=rope_grid_indexing, + ) + else: + self.rope_mixed = False + self.rope = RotaryEmbeddingCat( + dim=embed_dim // num_heads, + temperature=rope_temperature, + in_pixels=False, + feat_shape=None if dynamic_img_size else self.patch_embed.grid_size, + ref_feat_shape=ref_feat_shape, + grid_offset=rope_grid_offset, + grid_indexing=rope_grid_indexing, + ) else: + self.rope_mixed = False self.rope = None self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity() @@ -634,6 +667,8 @@ def _init_weights(self, m: nn.Module) -> None: def no_weight_decay(self) -> Set[str]: """Parameters to exclude from weight decay.""" nwd = {'pos_embed', 'cls_token'} + if (rope := getattr(self, "rope", None)) and hasattr(rope, "no_weight_decay"): + return nwd | {f"rope.{p}" for p in rope.no_weight_decay()} return nwd @torch.jit.ignore @@ -675,7 +710,7 @@ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: self.pos_embed, new_size=(H, W), old_size=prev_grid_size, - num_prefix_tokens=self.num_prefix_tokens, + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, ) else: pos_embed = None @@ -685,18 +720,24 @@ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: pos_embed = self.pos_embed rot_pos_embed = self.rope.get_embed() if self.rope is not None else None + to_cat = [] if self.cls_token is not None: - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - - if pos_embed is not None: - x = x + pos_embed - + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) if self.reg_token is not None: - to_cat = [] - if self.cls_token is not None: - to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) - x = torch.cat(to_cat + [x], dim=1) + + if self.no_embed_class: + # position embedding does not overlap with class / reg token + if pos_embed is not None: + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + else: + # pos_embed has entry for class / reg token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + if pos_embed is not None: + x = x + pos_embed x = self.pos_drop(x) @@ -741,13 +782,23 @@ def forward_intermediates( blocks = self.blocks else: blocks = self.blocks[:max_index + 1] - for i, blk in enumerate(blocks): - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(blk, x, rope=rot_pos_embed) - else: - x = blk(x, rope=rot_pos_embed) - if i in take_indices: - intermediates.append(self.norm(x) if norm else x) + # Handle depth-dependent embeddings for mixed mode + if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None: + for i, blk in enumerate(blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, rope=rot_pos_embed[i]) + else: + x = blk(x, rope=rot_pos_embed[i]) + if i in take_indices: + intermediates.append(self.norm(x) if norm else x) + else: + for i, blk in enumerate(blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, rope=rot_pos_embed) + else: + x = 
blk(x, rope=rot_pos_embed) + if i in take_indices: + intermediates.append(self.norm(x) if norm else x) # process intermediates if self.num_prefix_tokens: @@ -807,11 +858,23 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) x, rot_pos_embed = self._pos_embed(x) x = self.norm_pre(x) - for blk in self.blocks: - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(blk, x, rope=rot_pos_embed) - else: - x = blk(x, rope=rot_pos_embed) + + # Handle depth-dependent embeddings for mixed mode + if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None: + # rot_pos_embed has shape (depth, H*W, dim) for mixed mode + for i, blk in enumerate(self.blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, rope=rot_pos_embed[i]) + else: + x = blk(x, rope=rot_pos_embed[i]) + else: + # Standard path for non-mixed mode + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, rope=rot_pos_embed) + else: + x = blk(x, rope=rot_pos_embed) + x = self.norm(x) return x @@ -932,6 +995,7 @@ def checkpoint_filter_fn( Filtered state dictionary. """ out_dict = {} + # Standard EVA checkpoint processing state_dict = state_dict.get('model_ema', state_dict) state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('module', state_dict) @@ -960,7 +1024,7 @@ def checkpoint_filter_fn( continue k = k[len_prefix:] - if 'rope' in k: + if 'rope' in k and not k == 'rope.freqs': # fixed embedding no need to load buffer from checkpoint continue @@ -1301,6 +1365,80 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: input_size=(3, 448, 448), num_classes=0, ), + + # RoPE-ViT models from Naver + 'vit_small_patch16_rope_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 'vit_base_patch16_rope_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 'vit_large_patch16_rope_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 'vit_small_patch16_rope_mixed_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 'vit_base_patch16_rope_mixed_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 'vit_large_patch16_rope_mixed_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 'vit_small_patch16_rope_ape_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 'vit_base_patch16_rope_ape_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 'vit_large_patch16_rope_ape_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 'vit_small_patch16_rope_mixed_ape_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 'vit_base_patch16_rope_mixed_ape_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), + 
'vit_large_patch16_rope_mixed_ape_224.naver_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + license='apache-2.0', + ), }) @@ -1805,3 +1943,297 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> E ) return _create_eva('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + +# RoPE-ViT models from https://github.com/naver-ai/rope-vit +@register_model +def vit_small_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Axial ViT-S/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + attn_type='rope', + qkv_bias=True, + init_values=1e-5, + class_token=True, + global_pool='token', + use_abs_pos_emb=False, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=100.0, + ) + model = _create_eva('vit_small_patch16_rope_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Axial ViT-B/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + attn_type='rope', + use_fc_norm=False, + qkv_bias=True, + init_values=1e-5, + class_token=True, + global_pool='token', + use_abs_pos_emb=False, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=100.0, + ) + model = _create_eva('vit_base_patch16_rope_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Axial ViT-L/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + attn_type='rope', + qkv_bias=True, + init_values=1e-5, + class_token=True, + global_pool='token', + use_abs_pos_emb=False, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=100.0, + ) + model = _create_eva('vit_large_patch16_rope_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Mixed ViT-S/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + attn_type='rope', + qkv_bias=True, + init_values=1e-5, + class_token=True, + global_pool='token', + use_abs_pos_emb=False, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=10.0, + rope_mixed_mode=True, + ) + model = _create_eva('vit_small_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Mixed ViT-B/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + attn_type='rope', + init_values=1e-5, + class_token=True, + global_pool='token', + use_abs_pos_emb=False, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=10.0, + rope_mixed_mode=True, + ) + model = _create_eva('vit_base_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def 
vit_large_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Mixed ViT-L/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + attn_type='rope', + qkv_bias=True, + init_values=1e-5, + class_token=True, + global_pool='token', + use_abs_pos_emb=False, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=10.0, + rope_mixed_mode=True, + ) + model = _create_eva('vit_large_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +# APE variants (with absolute position embeddings) +@register_model +def vit_small_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Axial + APE ViT-S/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + attn_type='rope', + qkv_bias=True, + init_values=1e-5, + class_token=True, + global_pool='token', + no_embed_class=True, + use_abs_pos_emb=True, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=100.0, + ) + model = _create_eva('vit_small_patch16_rope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Axial + APE ViT-B/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + attn_type='rope', + qkv_bias=True, + init_values=1e-5, + class_token=True, + global_pool='token', + no_embed_class=True, + use_abs_pos_emb=True, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=100.0, + ) + + model = _create_eva('vit_base_patch16_rope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Axial + APE ViT-L/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + attn_type='rope', + qkv_bias=True, + init_values=1e-5, + class_token=True, + global_pool='token', + no_embed_class=True, + use_abs_pos_emb=True, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=100.0, + ) + + model = _create_eva('vit_large_patch16_rope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Mixed + APE ViT-S/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + attn_type='rope', + qkv_bias=True, + init_values=1e-5, + class_token=True, + global_pool='token', + no_embed_class=True, + use_abs_pos_emb=True, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=10.0, + rope_mixed_mode=True, + ) + + model = _create_eva('vit_small_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Mixed + APE ViT-B/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + attn_type='rope', + qkv_bias=True, + 
init_values=1e-5, + class_token=True, + global_pool='token', + no_embed_class=True, + use_abs_pos_emb=True, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=10.0, + rope_mixed_mode=True, + ) + model = _create_eva('vit_base_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva: + """RoPE-Mixed + APE ViT-L/16 from https://github.com/naver-ai/rope-vit""" + model_args = dict( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + attn_type='rope', + qkv_bias=True, + init_values=1e-5, + class_token=True, + global_pool='token', + no_embed_class=True, + use_abs_pos_emb=True, + use_rot_pos_emb=True, + rope_grid_indexing='xy', + rope_temperature=10.0, + rope_mixed_mode=True, + ) + model = _create_eva('vit_large_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model +
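
Usage sketch (not part of the patch): the snippet below assumes this branch of `timm` is installed, that the RoPE-ViT entrypoints and the relocated `timm.layers` FX helpers land exactly as registered in the diff above, and that torchvision's FX feature extraction is available. Pretrained weights depend on the `hf_hub_id='timm/'` uploads, so the model is built with `pretrained=False`; treat this as an illustrative sketch rather than a definitive API reference.

```python
# Sketch only: exercises the new RoPE-ViT variants and the timm.layers
# notrace helpers introduced in this patch (names assumed from the diff).
import torch
import timm
from timm.layers import (
    RotaryEmbeddingMixed,
    register_notrace_function,
    get_notrace_functions,
    get_graph_node_names,
)

# RoPE-Mixed ViT-S/16 registered in timm/models/eva.py above; built without
# pretrained weights since weight availability depends on the hub uploads.
model = timm.create_model('vit_small_patch16_rope_mixed_224', pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)             # torch.Size([1, 1000])
print(model.no_weight_decay())  # learnable 'rope.freqs' excluded from weight decay

# Standalone mixed-mode rotary table: one sin/cos embedding per block depth.
rope = RotaryEmbeddingMixed(dim=384, depth=12, num_heads=6, feat_shape=[14, 14])
print(rope.get_embed().shape)   # leading dim == depth, indexed per block in forward

# FX notrace registration now lives in timm.layers (wrapping torchvision FX).
@register_notrace_function
def my_custom_op(x: torch.Tensor) -> torch.Tensor:
    # hypothetical user function that FX should treat as a leaf
    return x + 1

assert my_custom_op in get_notrace_functions()
train_nodes, eval_nodes = get_graph_node_names(model)  # uses timm leaf / autowrap sets
```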