77 changes: 37 additions & 40 deletions src/anomalib/models/components/sampling/k_center_greedy.py
@@ -10,7 +10,6 @@
"""

import torch
from torch.nn import functional as F # noqa: N812
from tqdm import tqdm

from anomalib.models.components.dimensionality_reduction import SparseRandomProjection
@@ -58,84 +57,82 @@ def reset_distances(self) -> None:
"""Reset minimum distances to None."""
self.min_distances = None

def update_distances(self, cluster_centers: list[int]) -> None:
"""Update minimum distances given cluster centers.
def update_distances(self, cluster_centers: int | torch.Tensor | None) -> None:
"""Update minimum distances given a single cluster center.

Args:
cluster_centers (list[int]): Indices of cluster centers.
cluster_centers (int | torch.Tensor | None): Index of a single cluster center.
Can be an int, a 0-d tensor, or a 1-d tensor with shape [1].

Note:
This method is optimized for single-center updates. Passing multiple
indices may result in incorrect behavior or runtime errors.
"""
if cluster_centers:
if cluster_centers is not None:
centers = self.features[cluster_centers]

distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1)
# Ensure centers is a 1-d tensor for broadcasting
centers = centers.squeeze(0) if centers.dim() == 2 else centers
Copilot AI Oct 3, 2025

The condition centers.dim() == 2 assumes centers will only have 2 dimensions when it needs squeezing, but if cluster_centers is a 1-d tensor with shape [1], centers could already be 1-d and not need squeezing. This could cause issues if centers has other dimensions. Consider using centers = centers.squeeze() to remove all dimensions of size 1, or add more specific dimension checks.

Suggested change
centers = centers.squeeze(0) if centers.dim() == 2 else centers
centers = centers.squeeze()

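To make the shape concern concrete, here is a small standalone sketch (the feature dimension of 4 and the indices are arbitrary assumptions):

import torch

features = torch.randn(10, 4)                  # toy stand-in for self.features
center_from_int = features[3]                  # int index -> shape [4], already 1-d
center_from_vec = features[torch.tensor([3])]  # shape-[1] index -> shape [1, 4], 2-d

# squeeze(0) only covers the 2-d case; squeeze() drops every size-1 dimension
assert center_from_int.shape == (4,)
assert center_from_vec.squeeze(0).shape == (4,)
assert center_from_vec.squeeze().shape == (4,)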

# Using torch.linalg.norm is faster than torch.nn.functional.pairwise_distance on
# both CUDA and Intel-based hardware.
distances = torch.linalg.norm(self.features - centers, ord=2, dim=1, keepdim=True)
# Synchronize on XPU to ensure accurate progress bar display.
if distances.device.type == "xpu":
torch.xpu.synchronize()

if self.min_distances is None:
self.min_distances = distance
self.min_distances = distances
else:
self.min_distances = torch.minimum(self.min_distances, distance)
self.min_distances = torch.minimum(self.min_distances, distances)

def get_new_idx(self) -> int:
def get_new_idx(self) -> torch.Tensor:
"""Get index of the next sample based on maximum minimum distance.

Returns:
int: Index of the selected sample.
torch.Tensor: Index of the selected sample (tensor, not converted to int).

Raises:
TypeError: If `self.min_distances` is not a torch.Tensor.
"""
if isinstance(self.min_distances, torch.Tensor):
idx = int(torch.argmax(self.min_distances).item())
_, idx = torch.max(self.min_distances.squeeze(), dim=0)
Copilot AI Oct 3, 2025

Using squeeze() without arguments removes all dimensions of size 1, which could cause issues if min_distances has multiple dimensions that shouldn't be squeezed. Since the method expects a column vector (from keepdim=True in line 78), consider using squeeze(1) to only remove the second dimension, or verify the tensor shape before squeezing.

Suggested change
_, idx = torch.max(self.min_distances.squeeze(), dim=0)
_, idx = torch.max(self.min_distances.squeeze(1), dim=0)

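For context, a quick sketch of how the two squeeze variants behave on the expected [N, 1] column layout (the sizes are arbitrary assumptions; the single-row case is the degenerate one the suggestion guards against):

import torch

col = torch.randn(5, 1)                  # one min-distance per sample
assert col.squeeze().shape == (5,)       # fine with several samples
assert col.squeeze(1).shape == (5,)

single = torch.randn(1, 1)               # degenerate single-sample case
assert single.squeeze().shape == ()      # all size-1 dims removed -> 0-d tensor
assert single.squeeze(1).shape == (1,)   # only the column dim removed -> still 1-d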

else:
msg = f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}"
raise TypeError(msg)

return idx

def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[int]:
"""Greedily form a coreset to minimize maximum distance to cluster centers.
def select_coreset_idxs(self) -> list[int]:
"""Greedily select coreset indices to minimize maximum distance to centers.

Args:
selected_idxs (list[int] | None, optional): Indices of pre-selected
samples. Defaults to None.
The algorithm iteratively selects points that are farthest from the already
selected centers, starting from a random initial point.

Returns:
list[int]: Indices of samples selected to minimize distance to cluster
centers.

Raises:
ValueError: If a newly selected index is already in `selected_idxs`.
list[int]: Indices of samples selected to form the coreset.
"""
if selected_idxs is None:
selected_idxs = []

if self.embedding.ndim == 2:
self.model.fit(self.embedding)
self.features = self.model.transform(self.embedding)
self.reset_distances()
else:
self.features = self.embedding.reshape(self.embedding.shape[0], -1)
self.update_distances(cluster_centers=selected_idxs)
self.reset_distances()

# random starting point
idx = torch.randint(high=self.n_observations, size=(1,), device=self.features.device).squeeze()

selected_coreset_idxs: list[int] = []
idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
selected_coreset_idxs: list[torch.Tensor] = []
for _ in tqdm(range(self.coreset_size), desc="Selecting Coreset Indices."):
self.update_distances(cluster_centers=[idx])
self.update_distances(cluster_centers=idx.unsqueeze(0))
Copilot AI Oct 3, 2025

The method signature accepts int | torch.Tensor | None but here you're always passing a tensor with an extra dimension. This inconsistency with the parameter handling in update_distances could lead to confusion. Consider either updating the method to handle the tensor directly or consistently passing the same type.

Suggested change
self.update_distances(cluster_centers=idx.unsqueeze(0))
self.update_distances(cluster_centers=idx)

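One way to reconcile the two call sites, sketched as a hypothetical normalization helper (not part of this patch; the name _as_center_index is made up):

import torch

def _as_center_index(cluster_centers: int | torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: map an int, a 0-d tensor, or a shape-[1] tensor to a
    # 1-d index tensor so that features[...] always yields a [1, D] slice.
    if isinstance(cluster_centers, int):
        return torch.tensor([cluster_centers])
    return cluster_centers.reshape(1)

assert _as_center_index(3).shape == (1,)
assert _as_center_index(torch.tensor(3)).shape == (1,)
assert _as_center_index(torch.tensor([3])).shape == (1,)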

idx = self.get_new_idx()
if idx in selected_idxs:
msg = "New indices should not be in selected indices."
raise ValueError(msg)
self.min_distances[idx] = 0
self.min_distances.scatter_(0, idx.unsqueeze(0).unsqueeze(1), 0.0)
selected_coreset_idxs.append(idx)

return selected_coreset_idxs
return [int(tensor_idx.item()) for tensor_idx in selected_coreset_idxs]
Comment on lines +124 to +131
Copilot AI Oct 3, 2025

The type annotation indicates a list of tensors, but the final return statement converts these to integers. Consider using list[int] as the type annotation and converting tensors to integers immediately when appending, or keep tensors throughout and convert only at return.

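A minimal illustration of the two options the comment describes (the toy distances are an assumption):

import torch

idx = torch.argmax(torch.tensor([[0.2], [0.9], [0.4]]))  # 0-d index tensor

as_tensors: list[torch.Tensor] = [idx]                    # option kept in the patch
as_ints: list[int] = [int(idx.item())]                    # converted at append time

assert as_ints == [int(t.item()) for t in as_tensors] == [1]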


def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor:
def sample_coreset(self) -> torch.Tensor:
"""Select coreset from the embedding.

Args:
selected_idxs (list[int] | None, optional): Indices of pre-selected
samples. Defaults to None.

Returns:
torch.Tensor: Selected coreset.

@@ -147,5 +144,5 @@ def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor
>>> coreset.shape
torch.Size([219, 1536])
"""
idxs = self.select_coreset_idxs(selected_idxs)
idxs = self.select_coreset_idxs()
return self.embedding[idxs]
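
For reference, a usage sketch against the updated no-argument sample_coreset; the embedding size and sampling_ratio below are illustrative assumptions, and the import path follows the existing anomalib package layout:

import torch

from anomalib.models.components.sampling import KCenterGreedy

embedding = torch.randn(1000, 1536)                        # assumed toy embedding
sampler = KCenterGreedy(embedding=embedding, sampling_ratio=0.1)
coreset = sampler.sample_coreset()                         # no pre-selected indices anymore
print(coreset.shape)                                       # torch.Size([100, 1536])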