Hi!
Thank you for the excellent library and the recent addition of NaFlex transformers. I've been learning a lot from exploring it! While experimenting with the code, I identified a few points I'd like to discuss regarding the two key components of NaFlex compared to standard transformers: 1) on-the-fly patch embedding interpolation and 2) on-the-fly positional embedding interpolation.
Patch Embeddings
The interpolation of patch embeddings uses a somewhat more involved method based on computing a pseudo-inverse matrix. Since it operates on a single tensor of shape (C_out, C_in, P_h, P_w) (all patch sizes within a batch are identical), its performance impact is relatively minor. However, I noticed a small optimization that could improve efficiency. The current implementation uses torch.vmap to vectorize the calculation across channels:
def _apply_resampling(
        patch_embed: torch.Tensor,
        pinv_matrix: torch.Tensor,
        new_size_tuple: Tuple[int, int],
        orig_dtype: torch.dtype,
        intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE
) -> torch.Tensor:
    """Applies the precomputed pinv_matrix to resample the patch_embed tensor."""
    try:
        from torch import vmap
    except ImportError:
        from functorch import vmap

    def resample_kernel(kernel: torch.Tensor) -> torch.Tensor:
        kernel_flat = kernel.reshape(-1).to(intermediate_dtype)
        resampled_kernel_flat = pinv_matrix @ kernel_flat
        return resampled_kernel_flat.reshape(new_size_tuple)

    resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0)
    patch_embed_float = patch_embed.to(intermediate_dtype)
    resampled_patch_embed = resample_kernel_vmap(patch_embed_float)
    return resampled_patch_embed.to(orig_dtype)
I was unfamiliar with vmap (it seems more common in JAX-style workflows), and its use here was initially unclear to me. While it performs well, the operation is essentially a batched matrix multiplication, which can be expressed with the standard @ operator. This alternative reduces code complexity and slightly improves performance:
def _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype, intermediate_dtype):
    c_out, c_in, *_ = patch_embed.shape
    patch_embed = patch_embed.reshape(c_out, c_in, -1).to(dtype=intermediate_dtype)
    pinv_matrix = pinv_matrix.to(dtype=intermediate_dtype)
    # (C_out, C_in, P_old * P_old) @ (P_old * P_old, P_new * P_new)
    resampled_patch_embed = patch_embed @ pinv_matrix.T
    resampled_patch_embed = resampled_patch_embed.reshape(c_out, c_in, *new_size_tuple).to(dtype=orig_dtype)
    return resampled_patch_embed
In my tests, resampling a (768, 3, 32, 32) embedding to (768, 3, 16, 16) took 18 ms (forward) and 26 ms (forward + backward) with the vmap implementation. The batched matrix multiplication approach reduced this to 11 ms (forward) and 13 ms (forward + backward), roughly doubling the speed of the forward + backward pass while using less code.
Positional Embeddings
Unlike patch embeddings, where all images in a batch share the same patch size, positional embeddings must account for varying aspect ratios and dimensions across the images in a batch. This requires calling F.interpolate individually for each sample, currently implemented in a Python for-loop, which significantly impacts performance. There is an optimization that groups images of the same size and interpolates once per group, but this only helps when many images share the same size. In the original JAX implementation, vectorization with vmap improved efficiency, but PyTorch's F.interpolate is incompatible with torch.vmap, which appears to be a current limitation of PyTorch.
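To make the pattern concrete, here is a minimal sketch of the per-sample interpolation loop being described (my own illustration, not the actual timm implementation; the function name and signature are made up):

import torch.nn.functional as F

def interpolate_pos_embed_per_sample(pos_emb, sizes):
    # pos_emb: (H, W, C) learned positional grid; sizes: per-image (h, w) target grid sizes.
    pos_emb_bchw = pos_emb.permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
    out = []
    for h, w in sizes:
        # One F.interpolate call per image, since each target size can differ.
        resized = F.interpolate(pos_emb_bchw, size=(h, w), mode='bilinear', antialias=True)
        out.append(resized[0].permute(1, 2, 0).reshape(h * w, -1))  # (h * w, C)
    return out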
I tested performance with the following code:
import einx
import torch
from timm.models import NaFlexEmbeds
from torch.nn.utils.rnn import pad_sequence

def get_coords(h, w):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    coord = einx.rearrange('i j, i j -> (i j) (1 + 1)', y, x)
    return coord

# Generate a batch of 512 images with random heights and widths (up to 32 patches max in one dimension)
device = 'cuda'
B = 512
sizes = torch.randint(4, 32, (B, 2))
max_seq_len = sizes.prod(-1).amax()
coords = pad_sequence([get_coords(h, w) for h, w in sizes], batch_first=True).to(device=device)
x = torch.randn(B, coords.shape[1], 16 * 16 * 3).to(device=device)
patch_emb = NaFlexEmbeds(embed_dim=768, proj_type='linear', pos_embed_grid_size=(32, 32)).cuda()

%%time
patch_emb(x, coords).sum().backward()
torch.cuda.synchronize()
| Batch Size | Forward (ms) | Forward + Backward (ms) |
|---|---|---|
| 256 | 240 | 1940 |
| 512 | 370 | 7230 |
While the forward pass is reasonably fast, the backward pass is notably slow and scales super-linearly with batch size for some reason.
One idea, though it will produce slightly different results since there is no antialias=True equivalent, is to use F.grid_sample instead of F.interpolate; grid_sample does support resampling different images in a batch to different shapes in a vectorized manner. An example implementation might look like this:
def apply_naflex_pos_emb_grid_sample(pos_emb, x, coords, mode='bilinear', align_corners=False, padding_mode='zeros'):
    B = x.shape[0]
    C = pos_emb.shape[-1]
    pos_emb = einx.rearrange('h w c -> b c h w', pos_emb, b=B)
    shapes = coords.amax(1) + 1
    grid_size = shapes.amax(0)
    theta = torch.zeros(B, 2, 3, dtype=pos_emb.dtype, device=pos_emb.device)
    theta[:, 0, 0] = grid_size[1] / shapes[:, 1]  # scale x
    theta[:, 0, 2] = theta[:, 0, 0] - 1  # translate x
    theta[:, 1, 1] = grid_size[0] / shapes[:, 0]  # scale y
    theta[:, 1, 2] = theta[:, 1, 1] - 1  # translate y
    grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=align_corners)
    pos_emb = F.grid_sample(pos_emb, grid, mode=mode, align_corners=align_corners, padding_mode=padding_mode)
    bi = einx.rearrange('b -> b n', torch.arange(B, device=pos_emb.device), n=coords.shape[1])
    x = x + pos_emb[bi, :, coords[..., 0], coords[..., 1]]
    return x
This approach eliminates the Python for-loop, reducing the forward + backward pass for a batch size of 512 to just 150 ms. I haven't tested its impact on training performance compared to F.interpolate with antialias=True vs antialias=False, but for bilinear interpolation, F.grid_sample produces numerically similar results to F.interpolate with antialias=False.
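For what it's worth, this is the kind of check I would use to quantify "numerically similar" (a minimal sketch under my own assumptions: apply_pos_emb_loop is a hypothetical per-sample reference built on F.interpolate with antialias=False, and get_coords is reused from the benchmark above):

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def apply_pos_emb_loop(pos_emb, x, coords):
    # Hypothetical reference: one F.interpolate(antialias=False) call per sample.
    outs = []
    for i in range(x.shape[0]):
        h, w = (coords[i].amax(0) + 1).tolist()
        pe = F.interpolate(
            pos_emb.permute(2, 0, 1).unsqueeze(0), size=(h, w),
            mode='bilinear', align_corners=False, antialias=False,
        )[0].permute(1, 2, 0)  # (h, w, C)
        outs.append(x[i] + pe[coords[i, :, 0], coords[i, :, 1]])
    return torch.stack(outs)

B, C = 8, 768
sizes = torch.randint(4, 32, (B, 2))
coords = pad_sequence([get_coords(h, w) for h, w in sizes], batch_first=True)
x = torch.zeros(B, coords.shape[1], C)
pos_emb = torch.randn(32, 32, C)

ref = apply_pos_emb_loop(pos_emb, x, coords)
fast = apply_naflex_pos_emb_grid_sample(pos_emb, x, coords)
print((ref - fast).abs().max())  # should be small (floating-point differences only)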