-
Notifications
You must be signed in to change notification settings - Fork 821
🚀 feat(model): Make coreset selection for patchcore faster #2968
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
cd833a9
7d637d7
55b40a9
86f1488
251df36
58a9e1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -10,7 +10,6 @@ | |||||
""" | ||||||
|
||||||
import torch | ||||||
from torch.nn import functional as F # noqa: N812 | ||||||
from tqdm import tqdm | ||||||
|
||||||
from anomalib.models.components.dimensionality_reduction import SparseRandomProjection | ||||||
|
@@ -58,84 +57,82 @@ def reset_distances(self) -> None: | |||||
"""Reset minimum distances to None.""" | ||||||
self.min_distances = None | ||||||
|
||||||
def update_distances(self, cluster_centers: list[int]) -> None: | ||||||
"""Update minimum distances given cluster centers. | ||||||
def update_distances(self, cluster_centers: int | torch.Tensor | None) -> None: | ||||||
"""Update minimum distances given a single cluster center. | ||||||
|
||||||
Args: | ||||||
cluster_centers (list[int]): Indices of cluster centers. | ||||||
cluster_centers (int | torch.Tensor | None): Index of a single cluster center. | ||||||
Can be an int, a 0-d tensor, or a 1-d tensor with shape [1]. | ||||||
|
||||||
Note: | ||||||
This method is optimized for single-center updates. Passing multiple | ||||||
indices may result in incorrect behavior or runtime errors. | ||||||
""" | ||||||
if cluster_centers: | ||||||
if cluster_centers is not None: | ||||||
centers = self.features[cluster_centers] | ||||||
|
||||||
distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1) | ||||||
# Ensure centers is a 1-d tensor for broadcasting | ||||||
centers = centers.squeeze(0) if centers.dim() == 2 else centers | ||||||
# Using torch.linalg.norm is faster than torch.Functional.pairwise_distance on | ||||||
# both CUDA and Intel-based hardware. | ||||||
distances = torch.linalg.norm(self.features - centers, ord=2, dim=1, keepdim=True) | ||||||
# Synchronize on XPU to ensure accurate progress bar display. | ||||||
if distances.device.type == "xpu": | ||||||
torch.xpu.synchronize() | ||||||
|
||||||
if self.min_distances is None: | ||||||
self.min_distances = distance | ||||||
self.min_distances = distances | ||||||
else: | ||||||
self.min_distances = torch.minimum(self.min_distances, distance) | ||||||
self.min_distances = torch.minimum(self.min_distances, distances) | ||||||
|
||||||
def get_new_idx(self) -> int: | ||||||
def get_new_idx(self) -> torch.Tensor: | ||||||
"""Get index of the next sample based on maximum minimum distance. | ||||||
|
||||||
Returns: | ||||||
int: Index of the selected sample. | ||||||
torch.Tensor: Index of the selected sample (tensor, not converted to int). | ||||||
|
||||||
Raises: | ||||||
TypeError: If `self.min_distances` is not a torch.Tensor. | ||||||
""" | ||||||
if isinstance(self.min_distances, torch.Tensor): | ||||||
idx = int(torch.argmax(self.min_distances).item()) | ||||||
_, idx = torch.max(self.min_distances.squeeze(), dim=0) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||
else: | ||||||
msg = f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}" | ||||||
raise TypeError(msg) | ||||||
|
||||||
return idx | ||||||
|
||||||
def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[int]: | ||||||
"""Greedily form a coreset to minimize maximum distance to cluster centers. | ||||||
def select_coreset_idxs(self) -> list[int]: | ||||||
"""Greedily select coreset indices to minimize maximum distance to centers. | ||||||
|
||||||
Args: | ||||||
selected_idxs (list[int] | None, optional): Indices of pre-selected | ||||||
samples. Defaults to None. | ||||||
The algorithm iteratively selects points that are farthest from the already | ||||||
selected centers, starting from a random initial point. | ||||||
|
||||||
Returns: | ||||||
list[int]: Indices of samples selected to minimize distance to cluster | ||||||
centers. | ||||||
|
||||||
Raises: | ||||||
ValueError: If a newly selected index is already in `selected_idxs`. | ||||||
list[int]: Indices of samples selected to form the coreset. | ||||||
""" | ||||||
if selected_idxs is None: | ||||||
selected_idxs = [] | ||||||
|
||||||
if self.embedding.ndim == 2: | ||||||
self.model.fit(self.embedding) | ||||||
self.features = self.model.transform(self.embedding) | ||||||
self.reset_distances() | ||||||
else: | ||||||
self.features = self.embedding.reshape(self.embedding.shape[0], -1) | ||||||
self.update_distances(cluster_centers=selected_idxs) | ||||||
self.reset_distances() | ||||||
|
||||||
# random starting point | ||||||
idx = torch.randint(high=self.n_observations, size=(1,), device=self.features.device).squeeze() | ||||||
|
||||||
selected_coreset_idxs: list[int] = [] | ||||||
idx = int(torch.randint(high=self.n_observations, size=(1,)).item()) | ||||||
selected_coreset_idxs: list[torch.Tensor] = [] | ||||||
for _ in tqdm(range(self.coreset_size), desc="Selecting Coreset Indices."): | ||||||
self.update_distances(cluster_centers=[idx]) | ||||||
self.update_distances(cluster_centers=idx.unsqueeze(0)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The method signature accepts
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||
idx = self.get_new_idx() | ||||||
if idx in selected_idxs: | ||||||
msg = "New indices should not be in selected indices." | ||||||
raise ValueError(msg) | ||||||
self.min_distances[idx] = 0 | ||||||
self.min_distances.scatter_(0, idx.unsqueeze(0).unsqueeze(1), 0.0) | ||||||
selected_coreset_idxs.append(idx) | ||||||
|
||||||
return selected_coreset_idxs | ||||||
return [int(tensor_idx.item()) for tensor_idx in selected_coreset_idxs] | ||||||
Comment on lines +124 to +131
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type annotation indicates a list of tensors, but the final return statement converts these to integers. Consider using Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||
|
||||||
def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor: | ||||||
def sample_coreset(self) -> torch.Tensor: | ||||||
"""Select coreset from the embedding. | ||||||
|
||||||
Args: | ||||||
selected_idxs (list[int] | None, optional): Indices of pre-selected | ||||||
samples. Defaults to None. | ||||||
|
||||||
Returns: | ||||||
torch.Tensor: Selected coreset. | ||||||
|
||||||
|
@@ -147,5 +144,5 @@ def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor | |||||
>>> coreset.shape | ||||||
torch.Size([219, 1536]) | ||||||
""" | ||||||
idxs = self.select_coreset_idxs(selected_idxs) | ||||||
idxs = self.select_coreset_idxs() | ||||||
return self.embedding[idxs] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition `centers.dim() == 2` assumes centers will only have 2 dimensions when it needs squeezing, but if `cluster_centers` is a 1-d tensor with shape [1], `centers` could already be 1-d and not need squeezing. This could cause issues if centers has other dimensions. Consider using `centers = centers.squeeze()` to remove all dimensions of size 1, or add more specific dimension checks.
Copilot uses AI. Check for mistakes.
to remove all dimensions of size 1, or add more specific dimension checks.Copilot uses AI. Check for mistakes.