38 changes: 12 additions & 26 deletions timm/layers/patch_embed.py
@@ -270,47 +270,33 @@ def _compute_resize_matrix(

    eye_matrix = torch.eye(old_total, device=device, dtype=dtype)
    basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w)
-
    resized_basis_vectors_batch = F.interpolate(
        basis_vectors_batch,
        size=new_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False
    ) # Output shape: (old_total, 1, new_h, new_w)
-
-    resize_matrix = resized_basis_vectors_batch.squeeze(1).reshape(old_total, new_total).T
+    resize_matrix = resized_basis_vectors_batch.squeeze(1).permute(1, 2, 0).reshape(new_total, old_total)
    return resize_matrix # Shape: (new_total, old_total)


-def _compute_pinv_for_resampling(resize_matrix: torch.Tensor) -> torch.Tensor:
-    """Calculates the pseudoinverse matrix used for the resampling operation."""
-    pinv_matrix = torch.linalg.pinv(resize_matrix.T) # Shape: (new_total, old_total)
-    return pinv_matrix
-
-
def _apply_resampling(
        patch_embed: torch.Tensor,
        pinv_matrix: torch.Tensor,
        new_size_tuple: Tuple[int, int],
        orig_dtype: torch.dtype,
        intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE
) -> torch.Tensor:
-    """Applies the precomputed pinv_matrix to resample the patch_embed tensor."""
-    try:
-        from torch import vmap
-    except ImportError:
-        from functorch import vmap
-
-    def resample_kernel(kernel: torch.Tensor) -> torch.Tensor:
-        kernel_flat = kernel.reshape(-1).to(intermediate_dtype)
-        resampled_kernel_flat = pinv_matrix @ kernel_flat
-        return resampled_kernel_flat.reshape(new_size_tuple)
-
-    resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0)
-    patch_embed_float = patch_embed.to(intermediate_dtype)
-    resampled_patch_embed = resample_kernel_vmap(patch_embed_float)
-    return resampled_patch_embed.to(orig_dtype)
+    """ Simplified resampling w/o vmap use.
+    As proposed by https://github.com/stas-sl
+    """
+    c_out, c_in, *_ = patch_embed.shape
+    patch_embed = patch_embed.reshape(c_out, c_in, -1).to(dtype=intermediate_dtype)
+    pinv_matrix = pinv_matrix.to(dtype=intermediate_dtype)
+    resampled_patch_embed = patch_embed @ pinv_matrix # (C_out, C_in, P_old * P_old) @ (P_old * P_old, P_new * P_new)
+    resampled_patch_embed = resampled_patch_embed.reshape(c_out, c_in, *new_size_tuple).to(dtype=orig_dtype)
+    return resampled_patch_embed


def resample_patch_embed(
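Aside, not part of the diff: the new _apply_resampling body above replaces the old per-kernel vmap with a single batched matmul over flattened kernels. Below is a minimal self-contained sketch of that idea, using toy sizes and placeholder names rather than timm internals; it builds the resize matrix from interpolated one-hot basis images, takes its pseudoinverse, and checks that the batched matmul matches applying the pseudoinverse kernel by kernel, which is what the removed vmap path computed.

import torch
import torch.nn.functional as F

# Toy sizes; the same construction works for any old/new patch grid.
old_hw, new_hw = (4, 4), (8, 8)
old_total, new_total = old_hw[0] * old_hw[1], new_hw[0] * new_hw[1]

# Resize matrix: interpolate each one-hot basis image of the old grid to the new grid.
basis = torch.eye(old_total).reshape(old_total, 1, *old_hw)
resized = F.interpolate(basis, size=new_hw, mode='bilinear', antialias=True, align_corners=False)
resize_matrix = resized.squeeze(1).permute(1, 2, 0).reshape(new_total, old_total)
pinv = torch.linalg.pinv(resize_matrix)  # (old_total, new_total)

# Toy patch-embedding weight: (C_out, C_in, old_h, old_w)
weight = torch.randn(6, 3, *old_hw)

# New path: flatten the kernels and resample them all with one batched matmul.
batched = (weight.reshape(6, 3, -1) @ pinv).reshape(6, 3, *new_hw)

# Old path: apply the pseudoinverse to each kernel separately (what vmap vectorized).
looped = torch.stack([
    torch.stack([(pinv.T @ k.reshape(-1)).reshape(new_hw) for k in kernels])
    for kernels in weight
])
torch.testing.assert_close(batched, looped)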
@@ -336,7 +322,7 @@ def resample_patch_embed(
    resize_mat = _compute_resize_matrix(
        old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE
    )
-    pinv_matrix = _compute_pinv_for_resampling(resize_mat)
+    pinv_matrix = torch.linalg.pinv(resize_mat) # Calculates the pseudoinverse matrix used for resampling
    resampled_patch_embed = _apply_resampling(
        patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE
    )
@@ -388,7 +374,7 @@ def _get_or_create_pinv_matrix(
        resize_mat = _compute_resize_matrix(
            self.orig_size, new_size, self.interpolation, self.antialias, device, dtype
        )
-        pinv_matrix = _compute_pinv_for_resampling(resize_mat)
+        pinv_matrix = torch.linalg.pinv(resize_mat) # Calculates the pseudoinverse matrix used for resampling

        # Cache using register_buffer
        buffer_name = f"pinv_{new_size[0]}x{new_size[1]}"