Description
Describe the bug
When fused_attn is used, the attention scale is not passed to torch.nn.functional.scaled_dot_product_attention, so it falls back to the default of q.size(-1) ** -0.5 (i.e. dim_head ** -0.5). This differs from the scale used by the Attention2d layer's non-fused path (num_heads ** -0.5), so the fused and vanilla implementations produce different results.
pytorch-image-models/timm/layers/attention2d.py
Lines 294 to 351 in dafe866
class Attention2d(nn.Module):
    fused_attn: torch.jit.Final[bool]

    """ multi-head attention for 2D NCHW tensors"""
    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 32,
            bias: bool = True,
            expand_first: bool = False,
            head_first: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        self.num_heads = num_heads
        self.dim_head = dim_attn // num_heads
        self.head_first = head_first
        self.scale = num_heads ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        B, C, H, W = x.shape

        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

        if self.fused_attn:
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
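A minimal reproduction sketch, assuming timm 1.0.12 is installed. The import path follows the file quoted above, and flipping the fused_attn attribute directly on the module is only for demonstration; dim and num_heads are chosen so that dim_head != num_heads and the two scales actually differ.

    import torch
    from timm.layers.attention2d import Attention2d

    torch.manual_seed(0)
    # dim_head = 128 // 8 = 16, so dim_head ** -0.5 != num_heads ** -0.5
    attn = Attention2d(dim=128, num_heads=8).eval()
    x = torch.randn(1, 128, 7, 7)

    with torch.no_grad():
        attn.fused_attn = True    # SDPA path: scale defaults to dim_head ** -0.5 = 16 ** -0.5
        out_fused = attn(x)
        attn.fused_attn = False   # manual path: uses self.scale = num_heads ** -0.5 = 8 ** -0.5
        out_manual = attn(x)

    # Expected True, but the scale mismatch makes the outputs differ
    print(torch.allclose(out_fused, out_manual, atol=1e-5))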
Expected behavior
Same results for the two implementations.
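A sketch of one possible fix (not a maintainer-endorsed change): either pass the layer's scale explicitly to SDPA (the scale keyword is available in PyTorch >= 2.1), or compute self.scale from dim_head so the manual path matches the SDPA default. Which scale is the intended one is up to the maintainers.

    # Option 1: make the fused path use the layer's scale
    x = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(-1, -2).contiguous(),
        k.transpose(-1, -2).contiguous(),
        v.transpose(-1, -2).contiguous(),
        attn_mask=attn_mask,
        dropout_p=self.attn_drop.p if self.training else 0.,
        scale=self.scale,  # `scale` kwarg exists in PyTorch >= 2.1
    ).transpose(-1, -2).reshape(B, -1, H, W)

    # Option 2: make the manual path match SDPA's default (in __init__)
    # self.scale = self.dim_head ** -0.5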
Desktop (please complete the following information):
- OS: macOS
- This repository version: 1.0.12
- PyTorch version: 2.5 (CPU)