diff --git a/src/anomalib/models/components/sampling/k_center_greedy.py b/src/anomalib/models/components/sampling/k_center_greedy.py index ac9558c805..23df7a0b38 100644 --- a/src/anomalib/models/components/sampling/k_center_greedy.py +++ b/src/anomalib/models/components/sampling/k_center_greedy.py @@ -10,7 +10,6 @@ """ import torch -from torch.nn import functional as F # noqa: N812 from tqdm import tqdm from anomalib.models.components.dimensionality_reduction import SparseRandomProjection @@ -58,84 +57,82 @@ def reset_distances(self) -> None: """Reset minimum distances to None.""" self.min_distances = None - def update_distances(self, cluster_centers: list[int]) -> None: - """Update minimum distances given cluster centers. + def update_distances(self, cluster_center: int | torch.Tensor | None) -> None: + """Update minimum distances given a single cluster center. Args: - cluster_centers (list[int]): Indices of cluster centers. - """ - if cluster_centers: - centers = self.features[cluster_centers] + cluster_center (int | torch.Tensor | None): Index of a single cluster center. + Can be an int, a 0-d tensor, or a 1-d tensor with shape [1]. - distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1) + Note: + This method is optimized for single-center updates. Passing multiple + indices may result in incorrect behavior or runtime errors. + """ + if cluster_center is not None: + center = self.features[cluster_center] + + # Ensure center is a 1-d tensor for broadcasting + center = center.squeeze() + # Using torch.linalg.norm is faster than torch.Functional.pairwise_distance on + # both CUDA and Intel-based hardware. + distances = torch.linalg.norm(self.features - center, ord=2, dim=1, keepdim=True) + # Synchronize on XPU to ensure accurate progress bar display. + if distances.device.type == "xpu": + torch.xpu.synchronize() if self.min_distances is None: - self.min_distances = distance + self.min_distances = distances else: - self.min_distances = torch.minimum(self.min_distances, distance) + self.min_distances = torch.minimum(self.min_distances, distances) - def get_new_idx(self) -> int: + def get_new_idx(self) -> torch.Tensor: """Get index of the next sample based on maximum minimum distance. Returns: - int: Index of the selected sample. + torch.Tensor: Index of the selected sample (tensor, not converted to int). Raises: TypeError: If `self.min_distances` is not a torch.Tensor. """ if isinstance(self.min_distances, torch.Tensor): - idx = int(torch.argmax(self.min_distances).item()) + _, idx = torch.max(self.min_distances.squeeze(1), dim=0) else: msg = f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}" raise TypeError(msg) - return idx - def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[int]: - """Greedily form a coreset to minimize maximum distance to cluster centers. + def select_coreset_idxs(self) -> list[int]: + """Greedily select coreset indices to minimize maximum distance to centers. - Args: - selected_idxs (list[int] | None, optional): Indices of pre-selected - samples. Defaults to None. + The algorithm iteratively selects points that are farthest from the already + selected centers, starting from a random initial point. Returns: - list[int]: Indices of samples selected to minimize distance to cluster - centers. - - Raises: - ValueError: If a newly selected index is already in `selected_idxs`. + list[int]: Indices of samples selected to form the coreset. """ - if selected_idxs is None: - selected_idxs = [] - if self.embedding.ndim == 2: self.model.fit(self.embedding) self.features = self.model.transform(self.embedding) self.reset_distances() else: self.features = self.embedding.reshape(self.embedding.shape[0], -1) - self.update_distances(cluster_centers=selected_idxs) + self.reset_distances() + + # random starting point + idx = torch.randint(high=self.n_observations, size=(1,), device=self.features.device).squeeze() selected_coreset_idxs: list[int] = [] - idx = int(torch.randint(high=self.n_observations, size=(1,)).item()) for _ in tqdm(range(self.coreset_size), desc="Selecting Coreset Indices."): - self.update_distances(cluster_centers=[idx]) + self.update_distances(cluster_center=idx) idx = self.get_new_idx() - if idx in selected_idxs: - msg = "New indices should not be in selected indices." - raise ValueError(msg) - self.min_distances[idx] = 0 - selected_coreset_idxs.append(idx) + self.min_distances.scatter_(0, idx.unsqueeze(0).unsqueeze(1), 0.0) + selected_coreset_idxs.append(int(idx.item())) return selected_coreset_idxs - def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor: + def sample_coreset(self) -> torch.Tensor: """Select coreset from the embedding. - Args: - selected_idxs (list[int] | None, optional): Indices of pre-selected - samples. Defaults to None. - Returns: torch.Tensor: Selected coreset. @@ -147,5 +144,5 @@ def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor >>> coreset.shape torch.Size([219, 1536]) """ - idxs = self.select_coreset_idxs(selected_idxs) + idxs = self.select_coreset_idxs() return self.embedding[idxs]