Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions timm/layers/attention_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
pos_embed: str = '',
pool_type: str = 'token',
norm_layer: Optional[nn.Module] = None,
act_layer: Optional[nn.Module] = nn.GELU,
drop: float = 0.0,
):
super().__init__()
Expand All @@ -54,13 +55,18 @@ def __init__(

self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
if qk_norm:
qk_norm_layer = norm_layer or nn.LayerNorm
self.q_norm = qk_norm_layer(self.head_dim)
self.k_norm = qk_norm_layer(self.head_dim)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.proj = nn.Linear(embed_dim, embed_dim)
self.proj_drop = nn.Dropout(drop)

self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer)

self.init_weights()

Expand Down
Loading