timm/models/mobilenetv5.py (30 changes: 24 additions & 6 deletions)
@@ -7,13 +7,25 @@
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import (
-    SelectAdaptivePool2d, Linear, LayerType, PadType, RmsNorm2d, ConvNormAct, create_conv2d, get_norm_act_layer,
-    to_2tuple
+    SelectAdaptivePool2d,
+    Linear,
+    LayerType,
+    RmsNorm2d,
+    ConvNormAct,
+    create_conv2d,
+    get_norm_layer,
+    get_norm_act_layer,
+    to_2tuple,
 )
 from ._builder import build_model_with_cfg
 from ._efficientnet_blocks import SqueezeExcite, UniversalInvertedResidual
-from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
-    round_channels, resolve_act_layer
+from ._efficientnet_builder import (
+    BlockArgs,
+    EfficientNetBuilder,
+    decode_arch_def,
+    efficientnet_init_weights,
+    round_channels,
+)
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq, checkpoint
@@ -115,6 +127,7 @@ def __init__(
             num_classes: int = 1000,
             in_chans: int = 3,
             stem_size: int = 16,
+            stem_bias: bool = False,
             fix_stem: bool = False,
             num_features: int = 2048,
             pad_type: str = '',
@@ -155,7 +168,7 @@
         """
         super().__init__()
         act_layer = act_layer or nn.GELU
-        norm_layer = norm_layer or RmsNorm2d
+        norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
         self.num_classes = num_classes
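Since get_norm_layer() resolves string names to layer classes (and passes classes or None through unchanged), this change appears to let callers specify norm_layer as a string while keeping the RmsNorm2d fallback. A minimal sketch, assuming 'rmsnorm2d' is a registered norm name in the installed timm version:

from timm.layers import get_norm_layer, RmsNorm2d

# A string resolves to the corresponding class; a class passes through as-is,
# and None stays None, so `get_norm_layer(norm_layer) or RmsNorm2d` still
# falls back to RmsNorm2d when nothing is provided.
norm_cls = get_norm_layer('rmsnorm2d')  # assumed to map to RmsNorm2d
norm = norm_cls(64)                     # norm layer for a 64-channel feature map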
@@ -173,6 +186,7 @@ def __init__(
             kernel_size=3,
             stride=2,
             padding=pad_type,
+            bias=stem_bias,
             norm_layer=norm_layer,
             act_layer=act_layer,
         )
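The new stem_bias flag defaults to False, so existing configs and checkpoints are unaffected; it only adds a bias term to the stem convolution when requested. A hedged usage sketch, assuming create_model forwards extra kwargs to the model constructor and that the 'mobilenetv5_300m_enc' config added further down is registered:

import timm
import torch

# stem_bias=True is assumed to be forwarded through create_model into
# __init__, where it reaches the ConvNormAct stem as bias=stem_bias.
model = timm.create_model('mobilenetv5_300m_enc', pretrained=False, stem_bias=True)
out = model(torch.randn(1, 3, 768, 768))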
@@ -396,6 +410,7 @@ def __init__(
             block_args: BlockArgs,
             in_chans: int = 3,
             stem_size: int = 64,
+            stem_bias: bool = False,
             fix_stem: bool = False,
             pad_type: str = '',
             msfa_indices: Sequence[int] = (-2, -1),
@@ -412,7 +427,7 @@
     ):
         super().__init__()
         act_layer = act_layer or nn.GELU
-        norm_layer = norm_layer or RmsNorm2d
+        norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         se_layer = se_layer or SqueezeExcite
         self.num_classes = 0  # Exists to satisfy ._hub module APIs.
         self.drop_rate = drop_rate
@@ -427,6 +442,7 @@
             kernel_size=3,
             stride=2,
             padding=pad_type,
+            bias=stem_bias,
             norm_layer=norm_layer,
             act_layer=act_layer,
         )
@@ -786,12 +802,14 @@ def _cfg(url: str = '', **kwargs):
     # encoder-only configs
     'mobilenetv5_300m_enc': _cfg(
         #hf_hub_id='timm/',
+        mean=(0., 0., 0.), std=(1., 1., 1.),
         input_size=(3, 768, 768),
         num_classes=0),
 
     # WIP classification configs for testing
     'mobilenetv5_300m': _cfg(
         # hf_hub_id='timm/',
+        mean=(0., 0., 0.), std=(1., 1., 1.),
         input_size=(3, 768, 768),
         num_classes=0),
     'mobilenetv5_base.untrained': _cfg(
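Setting mean=(0., 0., 0.) and std=(1., 1., 1.) makes the normalization step an identity, so these configs expect raw [0, 1] pixel values at 768x768 rather than ImageNet-normalized inputs. A sketch of how that propagates into preprocessing, assuming the usual timm data helpers are available in the installed version:

import timm
from timm.data import resolve_model_data_config, create_transform

model = timm.create_model('mobilenetv5_300m_enc', pretrained=False)
data_cfg = resolve_model_data_config(model)  # picks up mean/std/input_size from the cfg above
transform = create_transform(**data_cfg, is_training=False)
# With zero mean and unit std, the Normalize step leaves the [0, 1] tensor untouched.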