|
| 1 | +""" Transformer in Transformer (TNT) in PyTorch |
| 2 | +
|
| 3 | +A PyTorch implement of TNT as described in |
| 4 | +'Transformer in Transformer' - https://arxiv.org/abs/2103.00112 |
| 5 | +
|
| 6 | +The official mindspore code is released and available at |
| 7 | +https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT |
| 8 | +""" |
| 9 | +import math |
| 10 | +import torch |
| 11 | +import torch.nn as nn |
| 12 | +from functools import partial |
| 13 | + |
| 14 | +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| 15 | +from timm.models.helpers import load_pretrained |
| 16 | +from timm.models.layers import DropPath, trunc_normal_ |
| 17 | +from timm.models.vision_transformer import Mlp |
| 18 | +from timm.models.registry import register_model |
| 19 | + |
| 20 | + |
| 21 | +def _cfg(url='', **kwargs): |
| 22 | + return { |
| 23 | + 'url': url, |
| 24 | + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, |
| 25 | + 'crop_pct': .9, 'interpolation': 'bicubic', |
| 26 | + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, |
| 27 | + 'first_conv': 'pixel_embed.proj', 'classifier': 'head', |
| 28 | + **kwargs |
| 29 | + } |
| 30 | + |
| 31 | + |
| 32 | +default_cfgs = { |
| 33 | + 'tnt_s_patch16_224': _cfg( |
| 34 | + url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar', |
| 35 | + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), |
| 36 | + ), |
| 37 | + 'tnt_b_patch16_224': _cfg( |
| 38 | + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), |
| 39 | + ), |
| 40 | +} |
| 41 | + |
| 42 | + |
| 43 | +class Attention(nn.Module): |
| 44 | + """ Multi-Head Attention |
| 45 | + """ |
| 46 | + def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): |
| 47 | + super().__init__() |
| 48 | + self.hidden_dim = hidden_dim |
| 49 | + self.num_heads = num_heads |
| 50 | + head_dim = hidden_dim // num_heads |
| 51 | + self.head_dim = head_dim |
| 52 | + self.scale = head_dim ** -0.5 |
| 53 | + |
| 54 | + self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias) |
| 55 | + self.v = nn.Linear(dim, dim, bias=qkv_bias) |
| 56 | + self.attn_drop = nn.Dropout(attn_drop, inplace=True) |
| 57 | + self.proj = nn.Linear(dim, dim) |
| 58 | + self.proj_drop = nn.Dropout(proj_drop, inplace=True) |
| 59 | + |
| 60 | + def forward(self, x): |
| 61 | + B, N, C = x.shape |
| 62 | + qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) |
| 63 | + q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple) |
| 64 | + v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) |
| 65 | + |
| 66 | + attn = (q @ k.transpose(-2, -1)) * self.scale |
| 67 | + attn = attn.softmax(dim=-1) |
| 68 | + attn = self.attn_drop(attn) |
| 69 | + |
| 70 | + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) |
| 71 | + x = self.proj(x) |
| 72 | + x = self.proj_drop(x) |
| 73 | + return x |
| 74 | + |
| 75 | + |
| 76 | +class Block(nn.Module): |
| 77 | + """ TNT Block |
| 78 | + """ |
| 79 | + def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., |
| 80 | + qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): |
| 81 | + super().__init__() |
| 82 | + # Inner transformer |
| 83 | + self.norm_in = norm_layer(in_dim) |
| 84 | + self.attn_in = Attention( |
| 85 | + in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias, |
| 86 | + attn_drop=attn_drop, proj_drop=drop) |
| 87 | + |
| 88 | + self.norm_mlp_in = norm_layer(in_dim) |
| 89 | + self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4), |
| 90 | + out_features=in_dim, act_layer=act_layer, drop=drop) |
| 91 | + |
| 92 | + self.norm1_proj = norm_layer(in_dim) |
| 93 | + self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True) |
| 94 | + # Outer transformer |
| 95 | + self.norm_out = norm_layer(dim) |
| 96 | + self.attn_out = Attention( |
| 97 | + dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, |
| 98 | + attn_drop=attn_drop, proj_drop=drop) |
| 99 | + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
| 100 | + |
| 101 | + self.norm_mlp = norm_layer(dim) |
| 102 | + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), |
| 103 | + out_features=dim, act_layer=act_layer, drop=drop) |
| 104 | + |
| 105 | + def forward(self, pixel_embed, patch_embed): |
| 106 | + # inner |
| 107 | + pixel_embed = pixel_embed + self.drop_path(self.attn_in(self.norm_in(pixel_embed))) |
| 108 | + pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) |
| 109 | + # outer |
| 110 | + B, N, C = patch_embed.size() |
| 111 | + patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)) |
| 112 | + patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed))) |
| 113 | + patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed))) |
| 114 | + return pixel_embed, patch_embed |
| 115 | + |
| 116 | + |
| 117 | +class PixelEmbed(nn.Module): |
| 118 | + """ Image to Pixel Embedding |
| 119 | + """ |
| 120 | + def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4): |
| 121 | + super().__init__() |
| 122 | + num_patches = (img_size // patch_size) ** 2 |
| 123 | + self.img_size = img_size |
| 124 | + self.num_patches = num_patches |
| 125 | + self.in_dim = in_dim |
| 126 | + new_patch_size = math.ceil(patch_size / stride) |
| 127 | + self.new_patch_size = new_patch_size |
| 128 | + |
| 129 | + self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) |
| 130 | + self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size) |
| 131 | + |
| 132 | + def forward(self, x, pixel_pos): |
| 133 | + B, C, H, W = x.shape |
| 134 | + assert H == self.img_size and W == self.img_size, \ |
| 135 | + f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})." |
| 136 | + x = self.proj(x) |
| 137 | + x = self.unfold(x) |
| 138 | + x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size) |
| 139 | + x = x + pixel_pos |
| 140 | + x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2) |
| 141 | + return x |
| 142 | + |
| 143 | + |
| 144 | +class TNT(nn.Module): |
| 145 | + """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 |
| 146 | + """ |
| 147 | + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12, |
| 148 | + num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., |
| 149 | + drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): |
| 150 | + super().__init__() |
| 151 | + self.num_classes = num_classes |
| 152 | + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models |
| 153 | + |
| 154 | + self.pixel_embed = PixelEmbed( |
| 155 | + img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride) |
| 156 | + num_patches = self.pixel_embed.num_patches |
| 157 | + self.num_patches = num_patches |
| 158 | + new_patch_size = self.pixel_embed.new_patch_size |
| 159 | + num_pixel = new_patch_size ** 2 |
| 160 | + |
| 161 | + self.norm1_proj = norm_layer(num_pixel * in_dim) |
| 162 | + self.proj = nn.Linear(num_pixel * in_dim, embed_dim) |
| 163 | + self.norm2_proj = norm_layer(embed_dim) |
| 164 | + |
| 165 | + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
| 166 | + self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) |
| 167 | + self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size)) |
| 168 | + self.pos_drop = nn.Dropout(p=drop_rate) |
| 169 | + |
| 170 | + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule |
| 171 | + blocks = [] |
| 172 | + for i in range(depth): |
| 173 | + blocks.append(Block( |
| 174 | + dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head, |
| 175 | + mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, |
| 176 | + drop_path=dpr[i], norm_layer=norm_layer)) |
| 177 | + self.blocks = nn.ModuleList(blocks) |
| 178 | + self.norm = norm_layer(embed_dim) |
| 179 | + |
| 180 | + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() |
| 181 | + |
| 182 | + trunc_normal_(self.cls_token, std=.02) |
| 183 | + trunc_normal_(self.patch_pos, std=.02) |
| 184 | + trunc_normal_(self.pixel_pos, std=.02) |
| 185 | + self.apply(self._init_weights) |
| 186 | + |
| 187 | + def _init_weights(self, m): |
| 188 | + if isinstance(m, nn.Linear): |
| 189 | + trunc_normal_(m.weight, std=.02) |
| 190 | + if isinstance(m, nn.Linear) and m.bias is not None: |
| 191 | + nn.init.constant_(m.bias, 0) |
| 192 | + elif isinstance(m, nn.LayerNorm): |
| 193 | + nn.init.constant_(m.bias, 0) |
| 194 | + nn.init.constant_(m.weight, 1.0) |
| 195 | + |
| 196 | + @torch.jit.ignore |
| 197 | + def no_weight_decay(self): |
| 198 | + return {'patch_pos', 'pixel_pos', 'cls_token'} |
| 199 | + |
| 200 | + def get_classifier(self): |
| 201 | + return self.head |
| 202 | + |
| 203 | + def reset_classifier(self, num_classes, global_pool=''): |
| 204 | + self.num_classes = num_classes |
| 205 | + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() |
| 206 | + |
| 207 | + def forward_features(self, x): |
| 208 | + B = x.shape[0] |
| 209 | + pixel_embed = self.pixel_embed(x, self.pixel_pos) |
| 210 | + |
| 211 | + patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1)))) |
| 212 | + patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1) |
| 213 | + patch_embed = patch_embed + self.patch_pos |
| 214 | + patch_embed = self.pos_drop(patch_embed) |
| 215 | + |
| 216 | + for blk in self.blocks: |
| 217 | + pixel_embed, patch_embed = blk(pixel_embed, patch_embed) |
| 218 | + |
| 219 | + patch_embed = self.norm(patch_embed) |
| 220 | + return patch_embed[:, 0] |
| 221 | + |
| 222 | + def forward(self, x): |
| 223 | + x = self.forward_features(x) |
| 224 | + x = self.head(x) |
| 225 | + return x |
| 226 | + |
| 227 | + |
| 228 | +@register_model |
| 229 | +def tnt_s_patch16_224(pretrained=False, **kwargs): |
| 230 | + model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, |
| 231 | + qkv_bias=False, **kwargs) |
| 232 | + model.default_cfg = default_cfgs['tnt_s_patch16_224'] |
| 233 | + if pretrained: |
| 234 | + load_pretrained( |
| 235 | + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) |
| 236 | + return model |
| 237 | + |
| 238 | + |
| 239 | +@register_model |
| 240 | +def tnt_b_patch16_224(pretrained=False, **kwargs): |
| 241 | + model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, |
| 242 | + qkv_bias=False, **kwargs) |
| 243 | + model.default_cfg = default_cfgs['tnt_b_patch16_224'] |
| 244 | + if pretrained: |
| 245 | + load_pretrained( |
| 246 | + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) |
| 247 | + return model |
0 commit comments