Naflex performance #2514

@stas-sl

Hi!

Thank you for the excellent library and the recent addition of NaFlex transformers. I've been learning a lot from exploring it! While experimenting with the code, I identified a few points I'd like to discuss regarding the two key components of NaFlex compared to standard transformers: 1) on-the-fly patch embedding interpolation and 2) on-the-fly positional embedding interpolation.

Patch Embeddings

The interpolation of patch embeddings uses a somewhat more involved method that computes a pseudo-inverse matrix. Since it operates on a single weight tensor of shape (C_out, C_in, P_h, P_w) (all patch sizes within a batch are identical), its performance impact is relatively minor. However, I noticed a small optimization that could improve efficiency. The current implementation uses torch.vmap to vectorize calculations across channels:

def _apply_resampling(
    patch_embed: torch.Tensor,
    pinv_matrix: torch.Tensor,
    new_size_tuple: Tuple[int, int],
    orig_dtype: torch.dtype,
    intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE
) -> torch.Tensor:
    """Applies the precomputed pinv_matrix to resample the patch_embed tensor."""
    try:
        from torch import vmap
    except ImportError:
        from functorch import vmap

    def resample_kernel(kernel: torch.Tensor) -> torch.Tensor:
        kernel_flat = kernel.reshape(-1).to(intermediate_dtype)
        resampled_kernel_flat = pinv_matrix @ kernel_flat
        return resampled_kernel_flat.reshape(new_size_tuple)

    resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0)
    patch_embed_float = patch_embed.to(intermediate_dtype)
    resampled_patch_embed = resample_kernel_vmap(patch_embed_float)
    return resampled_patch_embed.to(orig_dtype)

I was unfamiliar with vmap (which seems more common in JAX-style workflows), and its use here was initially unclear. While it performs well, the operation is essentially a batched matrix multiplication, which can be simplified using the standard @ operator. This alternative approach reduces code complexity and slightly improves performance:

def _apply_resampling(
    patch_embed: torch.Tensor,
    pinv_matrix: torch.Tensor,
    new_size_tuple: Tuple[int, int],
    orig_dtype: torch.dtype,
    intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE,
) -> torch.Tensor:
    c_out, c_in, *_ = patch_embed.shape
    # Flatten each (P_h, P_w) kernel so the pseudo-inverse applies as one batched matmul.
    patch_embed = patch_embed.reshape(c_out, c_in, -1).to(dtype=intermediate_dtype)
    pinv_matrix = pinv_matrix.to(dtype=intermediate_dtype)
    resampled_patch_embed = patch_embed @ pinv_matrix.T  # (C_out, C_in, P_old*P_old) @ (P_old*P_old, P_new*P_new)
    resampled_patch_embed = resampled_patch_embed.reshape(c_out, c_in, *new_size_tuple).to(dtype=orig_dtype)
    return resampled_patch_embed

In my tests, resampling a (768, 3, 32, 32) embedding to (768, 3, 16, 16) took 18 ms (forward) and 26 ms (forward + backward) with the vmap implementation. The batched matrix multiplication approach reduced this to 11 ms (forward) and 13 ms (forward + backward), roughly doubling the speed for the forward + backward pass while using less code.
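For reference, a minimal harness along these lines reproduces the comparison (my own sketch, not from timm: _apply_resampling_vmap and _apply_resampling_matmul stand for the two variants above bound to distinct names, and the random pinv_matrix is a stand-in for the real precomputed pseudo-inverse):

import time
import torch

c_out, c_in, p_old, p_new = 768, 3, 32, 16
patch_embed = torch.randn(c_out, c_in, p_old, p_old, device='cuda', requires_grad=True)
pinv_matrix = torch.randn(p_new * p_new, p_old * p_old, device='cuda')  # stand-in matrix

# Sanity check: both variants should produce (near-)identical results.
out_vmap = _apply_resampling_vmap(patch_embed, pinv_matrix, (p_new, p_new), torch.float32, torch.float32)
out_mm = _apply_resampling_matmul(patch_embed, pinv_matrix, (p_new, p_new), torch.float32, torch.float32)
assert torch.allclose(out_vmap, out_mm, atol=1e-4)

# Time forward + backward for one variant; synchronize around the timed region.
torch.cuda.synchronize()
t0 = time.perf_counter()
out = _apply_resampling_matmul(patch_embed, pinv_matrix, (p_new, p_new), torch.float32, torch.float32)
out.sum().backward()
torch.cuda.synchronize()
print(f'forward + backward: {(time.perf_counter() - t0) * 1e3:.1f} ms')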

Positional Embeddings

Unlike patch embeddings, where all images in a batch share the same patch size, positional embeddings must account for varying aspect ratios and dimensions across images in a batch. This requires calling F.interpolate individually for each sample, currently implemented in a Python for-loop, which significantly impacts performance. There is an optimization that groups similarly sized images and interpolates once per group, but it only helps when many images share the same size. In the original JAX implementation, vectorization with vmap improved efficiency, but PyTorch's F.interpolate is not compatible with torch.vmap, which appears to be a current limitation of PyTorch.
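For context, the per-sample interpolation described above looks roughly like this (a simplified sketch, not timm's actual code; the function name and argument layout are illustrative):

import torch
import torch.nn.functional as F

def apply_pos_emb_loop(pos_emb, x, shapes):
    # pos_emb: (H, W, C) learned grid; x: (B, N, C) padded tokens;
    # shapes: (B, 2) target (h, w) per sample, in patches.
    pos_emb_bchw = pos_emb.permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
    for i, (h, w) in enumerate(shapes.tolist()):
        # One F.interpolate call per sample -- this Python loop is the bottleneck.
        resized = F.interpolate(pos_emb_bchw, size=(h, w), mode='bilinear', antialias=True)
        flat = resized.squeeze(0).permute(1, 2, 0).reshape(h * w, -1)  # (h * w, C)
        x[i, :h * w] = x[i, :h * w] + flat
    return x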

I tested performance with the following code:

import einx
import torch
from timm.models import NaFlexEmbeds
from torch.nn.utils.rnn import pad_sequence

def get_coords(h, w):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    coord = einx.rearrange('i j, i j -> (i j) (1 + 1)', y, x)
    return coord

# Generate a batch of 512 images with random heights and widths (4 to 31 patches per dimension)
device = 'cuda'
B = 512
sizes = torch.randint(4, 32, (B, 2))
max_seq_len = sizes.prod(-1).amax()
coords = pad_sequence([get_coords(h, w) for h, w in sizes], batch_first=True).to(device=device)
x = torch.randn(B, coords.shape[1], 16 * 16 * 3).to(device=device)

patch_emb = NaFlexEmbeds(embed_dim=768, proj_type='linear', pos_embed_grid_size=(32, 32)).cuda()

# In a separate Jupyter cell (%%time must be the first line of its cell):
%%time
patch_emb(x, coords).sum().backward()
torch.cuda.synchronize()
Batch Size    Forward (ms)    Forward + Backward (ms)
256           240             1940
512           370             7230

While the forward pass is reasonably fast, the backward pass is notably slow and, for reasons I haven't tracked down, scales super-linearly with batch size.
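One way to see where that backward time goes would be standard torch.profiler usage on the benchmark above (a suggestion, not how the numbers here were gathered):

from torch.profiler import profile, ProfilerActivity

# Reuses patch_emb, x, coords from the benchmark setup above.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    patch_emb(x, coords).sum().backward()
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))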

One idea is to use F.grid_sample instead of F.interpolate: it supports resampling different images in a batch to different shapes in a vectorized manner, though it will produce slightly different results since grid_sample has no antialias=True equivalent.

An example implementation using it might look like this:

import einx
import torch
import torch.nn.functional as F

def apply_naflex_pos_emb_grid_sample(pos_emb, x, coords, mode='bilinear', align_corners=False, padding_mode='zeros'):
    B = x.shape[0]
    C = pos_emb.shape[-1]
    pos_emb = einx.rearrange('h w c -> b c h w', pos_emb, b=B)  # broadcast the (H, W, C) grid across the batch
    shapes = coords.amax(1) + 1  # per-sample (h, w) in patches
    grid_size = shapes.amax(0)   # max (h, w) over the batch
    # Per-sample affine transform: the top-left (h, w) region of each output
    # samples the full pos_emb grid, i.e. each sample gets its own resize.
    theta = torch.zeros(B, 2, 3, dtype=pos_emb.dtype, device=pos_emb.device)
    theta[:, 0, 0] = grid_size[1] / shapes[:, 1]  # scale x
    theta[:, 0, 2] = theta[:, 0, 0] - 1           # translate x
    theta[:, 1, 1] = grid_size[0] / shapes[:, 0]  # scale y
    theta[:, 1, 2] = theta[:, 1, 1] - 1           # translate y
    grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=align_corners)
    pos_emb = F.grid_sample(pos_emb, grid, mode=mode, align_corners=align_corners, padding_mode=padding_mode)
    # Gather the interpolated embedding at each token's (y, x) coordinate.
    bi = einx.rearrange('b -> b n', torch.arange(B, device=pos_emb.device), n=coords.shape[1])
    x = x + pos_emb[bi, :, coords[..., 0], coords[..., 1]]
    return x

This approach eliminates the Python for-loop, reducing the forward + backward pass for a batch size of 512 to just 150 ms. I haven't tested how training is affected by antialias=True versus antialias=False in F.interpolate, but for bilinear interpolation F.grid_sample produces numerically similar results to F.interpolate with antialias=False.
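As a quick sanity check of that last claim, the grid_sample path can be compared against a direct F.interpolate call for a single target size (my own sketch, assuming the apply_naflex_pos_emb_grid_sample above; the shapes are arbitrary):

import torch
import torch.nn.functional as F

pos_emb = torch.randn(32, 32, 768)
h, w = 11, 23
yy, xx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
coords = torch.stack([yy, xx], dim=-1).reshape(1, h * w, 2)
tokens = torch.zeros(1, h * w, 768)

out = apply_naflex_pos_emb_grid_sample(pos_emb, tokens, coords)

# Reference: plain bilinear resize without antialiasing.
ref = F.interpolate(pos_emb.permute(2, 0, 1).unsqueeze(0), size=(h, w),
                    mode='bilinear', align_corners=False, antialias=False)
ref = ref.squeeze(0).permute(1, 2, 0).reshape(1, h * w, 768)
print((out - ref).abs().max())  # expected to be small (float noise) for bilinear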
