Skip to content

Commit 1ad1645

Browse files
committed
Merge branch 'contrastive-master'
2 parents 2319cbb + 51febd8 commit 1ad1645

File tree

3 files changed

+252
-3
lines changed

3 files changed

+252
-3
lines changed

tests/test_models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
torch._C._jit_set_profiling_mode(False)
1515

1616
# transformer models don't support many of the spatial / feature based model functionalities
17-
NON_STD_FILTERS = ['vit_*']
17+
NON_STD_FILTERS = ['vit_*', 'tnt_*']
18+
NUM_NON_STD = len(NON_STD_FILTERS)
1819

1920
# exclude models that cause specific test failures
2021
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
@@ -31,7 +32,7 @@
3132

3233

3334
@pytest.mark.timeout(120)
34-
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1]))
35+
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD]))
3536
@pytest.mark.parametrize('batch_size', [1])
3637
def test_model_forward(model_name, batch_size):
3738
"""Run a single forward pass with each model"""

timm/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .efficientnet import *
77
from .gluon_resnet import *
88
from .gluon_xception import *
9+
from .hardcorenas import *
910
from .hrnet import *
1011
from .inception_resnet_v2 import *
1112
from .inception_v3 import *
@@ -23,13 +24,13 @@
2324
from .selecsls import *
2425
from .senet import *
2526
from .sknet import *
27+
from .tnt import *
2628
from .tresnet import *
2729
from .vgg import *
2830
from .vision_transformer import *
2931
from .vovnet import *
3032
from .xception import *
3133
from .xception_aligned import *
32-
from .hardcorenas import *
3334

3435
from .factory import create_model, split_model_name, safe_model_name
3536
from .helpers import load_checkpoint, resume_checkpoint, model_parameters

timm/models/tnt.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
""" Transformer in Transformer (TNT) in PyTorch
2+
3+
A PyTorch implement of TNT as described in
4+
'Transformer in Transformer' - https://arxiv.org/abs/2103.00112
5+
6+
The official mindspore code is released and available at
7+
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
8+
"""
9+
import math
10+
import torch
11+
import torch.nn as nn
12+
from functools import partial
13+
14+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15+
from timm.models.helpers import load_pretrained
16+
from timm.models.layers import DropPath, trunc_normal_
17+
from timm.models.vision_transformer import Mlp
18+
from timm.models.registry import register_model
19+
20+
21+
def _cfg(url='', **kwargs):
22+
return {
23+
'url': url,
24+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
25+
'crop_pct': .9, 'interpolation': 'bicubic',
26+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
27+
'first_conv': 'pixel_embed.proj', 'classifier': 'head',
28+
**kwargs
29+
}
30+
31+
32+
default_cfgs = {
33+
'tnt_s_patch16_224': _cfg(
34+
url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
35+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
36+
),
37+
'tnt_b_patch16_224': _cfg(
38+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
39+
),
40+
}
41+
42+
43+
class Attention(nn.Module):
44+
""" Multi-Head Attention
45+
"""
46+
def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
47+
super().__init__()
48+
self.hidden_dim = hidden_dim
49+
self.num_heads = num_heads
50+
head_dim = hidden_dim // num_heads
51+
self.head_dim = head_dim
52+
self.scale = head_dim ** -0.5
53+
54+
self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias)
55+
self.v = nn.Linear(dim, dim, bias=qkv_bias)
56+
self.attn_drop = nn.Dropout(attn_drop, inplace=True)
57+
self.proj = nn.Linear(dim, dim)
58+
self.proj_drop = nn.Dropout(proj_drop, inplace=True)
59+
60+
def forward(self, x):
61+
B, N, C = x.shape
62+
qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
63+
q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple)
64+
v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
65+
66+
attn = (q @ k.transpose(-2, -1)) * self.scale
67+
attn = attn.softmax(dim=-1)
68+
attn = self.attn_drop(attn)
69+
70+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
71+
x = self.proj(x)
72+
x = self.proj_drop(x)
73+
return x
74+
75+
76+
class Block(nn.Module):
77+
""" TNT Block
78+
"""
79+
def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4.,
80+
qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
81+
super().__init__()
82+
# Inner transformer
83+
self.norm_in = norm_layer(in_dim)
84+
self.attn_in = Attention(
85+
in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias,
86+
attn_drop=attn_drop, proj_drop=drop)
87+
88+
self.norm_mlp_in = norm_layer(in_dim)
89+
self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4),
90+
out_features=in_dim, act_layer=act_layer, drop=drop)
91+
92+
self.norm1_proj = norm_layer(in_dim)
93+
self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True)
94+
# Outer transformer
95+
self.norm_out = norm_layer(dim)
96+
self.attn_out = Attention(
97+
dim, dim, num_heads=num_heads, qkv_bias=qkv_bias,
98+
attn_drop=attn_drop, proj_drop=drop)
99+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
100+
101+
self.norm_mlp = norm_layer(dim)
102+
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
103+
out_features=dim, act_layer=act_layer, drop=drop)
104+
105+
def forward(self, pixel_embed, patch_embed):
106+
# inner
107+
pixel_embed = pixel_embed + self.drop_path(self.attn_in(self.norm_in(pixel_embed)))
108+
pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
109+
# outer
110+
B, N, C = patch_embed.size()
111+
patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))
112+
patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
113+
patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
114+
return pixel_embed, patch_embed
115+
116+
117+
class PixelEmbed(nn.Module):
118+
""" Image to Pixel Embedding
119+
"""
120+
def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
121+
super().__init__()
122+
num_patches = (img_size // patch_size) ** 2
123+
self.img_size = img_size
124+
self.num_patches = num_patches
125+
self.in_dim = in_dim
126+
new_patch_size = math.ceil(patch_size / stride)
127+
self.new_patch_size = new_patch_size
128+
129+
self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
130+
self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
131+
132+
def forward(self, x, pixel_pos):
133+
B, C, H, W = x.shape
134+
assert H == self.img_size and W == self.img_size, \
135+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."
136+
x = self.proj(x)
137+
x = self.unfold(x)
138+
x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size)
139+
x = x + pixel_pos
140+
x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
141+
return x
142+
143+
144+
class TNT(nn.Module):
145+
""" Transformer in Transformer - https://arxiv.org/abs/2103.00112
146+
"""
147+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12,
148+
num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
149+
drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4):
150+
super().__init__()
151+
self.num_classes = num_classes
152+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
153+
154+
self.pixel_embed = PixelEmbed(
155+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride)
156+
num_patches = self.pixel_embed.num_patches
157+
self.num_patches = num_patches
158+
new_patch_size = self.pixel_embed.new_patch_size
159+
num_pixel = new_patch_size ** 2
160+
161+
self.norm1_proj = norm_layer(num_pixel * in_dim)
162+
self.proj = nn.Linear(num_pixel * in_dim, embed_dim)
163+
self.norm2_proj = norm_layer(embed_dim)
164+
165+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
166+
self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
167+
self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size))
168+
self.pos_drop = nn.Dropout(p=drop_rate)
169+
170+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
171+
blocks = []
172+
for i in range(depth):
173+
blocks.append(Block(
174+
dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head,
175+
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate,
176+
drop_path=dpr[i], norm_layer=norm_layer))
177+
self.blocks = nn.ModuleList(blocks)
178+
self.norm = norm_layer(embed_dim)
179+
180+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
181+
182+
trunc_normal_(self.cls_token, std=.02)
183+
trunc_normal_(self.patch_pos, std=.02)
184+
trunc_normal_(self.pixel_pos, std=.02)
185+
self.apply(self._init_weights)
186+
187+
def _init_weights(self, m):
188+
if isinstance(m, nn.Linear):
189+
trunc_normal_(m.weight, std=.02)
190+
if isinstance(m, nn.Linear) and m.bias is not None:
191+
nn.init.constant_(m.bias, 0)
192+
elif isinstance(m, nn.LayerNorm):
193+
nn.init.constant_(m.bias, 0)
194+
nn.init.constant_(m.weight, 1.0)
195+
196+
@torch.jit.ignore
197+
def no_weight_decay(self):
198+
return {'patch_pos', 'pixel_pos', 'cls_token'}
199+
200+
def get_classifier(self):
201+
return self.head
202+
203+
def reset_classifier(self, num_classes, global_pool=''):
204+
self.num_classes = num_classes
205+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
206+
207+
def forward_features(self, x):
208+
B = x.shape[0]
209+
pixel_embed = self.pixel_embed(x, self.pixel_pos)
210+
211+
patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
212+
patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
213+
patch_embed = patch_embed + self.patch_pos
214+
patch_embed = self.pos_drop(patch_embed)
215+
216+
for blk in self.blocks:
217+
pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
218+
219+
patch_embed = self.norm(patch_embed)
220+
return patch_embed[:, 0]
221+
222+
def forward(self, x):
223+
x = self.forward_features(x)
224+
x = self.head(x)
225+
return x
226+
227+
228+
@register_model
229+
def tnt_s_patch16_224(pretrained=False, **kwargs):
230+
model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
231+
qkv_bias=False, **kwargs)
232+
model.default_cfg = default_cfgs['tnt_s_patch16_224']
233+
if pretrained:
234+
load_pretrained(
235+
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
236+
return model
237+
238+
239+
@register_model
240+
def tnt_b_patch16_224(pretrained=False, **kwargs):
241+
model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4,
242+
qkv_bias=False, **kwargs)
243+
model.default_cfg = default_cfgs['tnt_b_patch16_224']
244+
if pretrained:
245+
load_pretrained(
246+
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
247+
return model

0 commit comments

Comments
 (0)