🚀 feat(model): Make coreset selection for patchcore faster #2968
Signed-off-by: rajeshgangireddy <[email protected]>
…d benchmark scripts
Keeping this as a draft for now, as I want to check whether we even need the case where there is more than one center, and remove it if not.
I have made some changes to clean up dead code (the multiple-centers case).
Pull Request Overview
This PR optimizes the coreset selection algorithm in PatchCore by replacing inefficient distance calculations and reducing CPU-GPU memory transfer overhead. The changes aim to significantly improve performance, particularly for large datasets.
- Replaced `F.pairwise_distance()` with `torch.linalg.norm()` for faster distance calculations
- Reduced CPU-GPU transfer overhead by minimizing `.item()` calls and keeping tensor operations on device
- Simplified the API by removing the unused `selected_idxs` parameter from coreset selection methods
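The first bullet can be illustrated with a minimal sketch. The tensor shapes and the names `features` and `center` below are assumptions chosen for illustration, not the PR's actual code. It only shows that the two formulations give the same distances up to the small `eps` that `F.pairwise_distance` adds; it makes no claim about speed.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
features = torch.randn(1000, 64)  # hypothetical embedding matrix: N patches x D dims
center = features[42]             # one selected point, shape (D,)

# Older formulation: pairwise_distance broadcasts the (1, D) center against (N, D).
# It adds a small eps (default 1e-6) to the difference before taking the norm.
d_old = F.pairwise_distance(features, center.unsqueeze(0), p=2).reshape(-1, 1)

# Newer formulation: explicit difference followed by a vector 2-norm along dim 1.
d_new = torch.linalg.norm(features - center, ord=2, dim=1, keepdim=True)

# Both results are column vectors of shape (N, 1) and agree up to the eps term.
print(d_old.shape, d_new.shape, torch.allclose(d_old, d_new, atol=1e-4))
```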
distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1)
# Ensure centers is a 1-d tensor for broadcasting
centers = centers.squeeze(0) if centers.dim() == 2 else centers
The condition `centers.dim() == 2` assumes centers will only have 2 dimensions when it needs squeezing, but if `cluster_centers` is a 1-d tensor with shape [1], `centers` could already be 1-d and not need squeezing. This could cause issues if centers has other dimensions. Consider using `centers = centers.squeeze()` to remove all dimensions of size 1, or add more specific dimension checks.
Suggested change:
- centers = centers.squeeze(0) if centers.dim() == 2 else centers
+ centers = centers.squeeze()
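To make the trade-off in this suggestion concrete, here is a small sketch of how `squeeze(0)` and `squeeze()` differ. The shapes are hypothetical; in particular, the `(1, 1)` case assumes a feature dimension of 1, which may never occur in PatchCore but shows where the two calls diverge.

```python
import torch

# Typical case: a single center stored as a (1, D) matrix.
centers_2d = torch.zeros(1, 8)
print(centers_2d.squeeze(0).shape)  # only dim 0 is removed
print(centers_2d.squeeze().shape)   # every size-1 dim is removed; same result here

# Edge case (assumed, for illustration): D == 1, so both dims have size 1.
edge = torch.zeros(1, 1)
shape_targeted = edge.squeeze(0).shape  # still 1-d, the feature axis survives
shape_blanket = edge.squeeze().shape    # 0-d, the feature axis is lost

# A tensor that is already 1-d passes through both calls unchanged when D > 1.
already_1d = torch.zeros(8)
```

The argument-free `squeeze()` is shorter, but the targeted `squeeze(0)` never touches the feature axis.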
"""
if isinstance(self.min_distances, torch.Tensor):
idx = int(torch.argmax(self.min_distances).item())
_, idx = torch.max(self.min_distances.squeeze(), dim=0)
Using `squeeze()` without arguments removes all dimensions of size 1, which could cause issues if `min_distances` has multiple dimensions that shouldn't be squeezed. Since the method expects a column vector (from `keepdim=True` in line 78), consider using `squeeze(1)` to only remove the second dimension, or verify the tensor shape before squeezing.
Suggested change:
- _, idx = torch.max(self.min_distances.squeeze(), dim=0)
+ _, idx = torch.max(self.min_distances.squeeze(1), dim=0)
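The suggestion above can be checked on a toy column vector. The values are made up for illustration. The sketch shows that `torch.max(..., dim=0)` returns the index as a 0-d tensor, so no `.item()` call is needed at this step, and that `squeeze(1)` keeps a 1-d result even when there is only one row.

```python
import torch

# Hypothetical (N, 1) column of minimum distances, as produced with keepdim=True.
min_distances = torch.tensor([[0.5], [2.0], [1.25]])

# torch.max with dim returns (values, indices); idx stays a tensor.
_, idx = torch.max(min_distances.squeeze(1), dim=0)
print(idx.dim(), idx.dtype)  # 0-d int64 tensor holding the position of the largest distance

# Single-row case: squeeze(1) keeps the row axis, squeeze() drops it as well.
single = torch.tensor([[3.0]])
shape_with_dim = single.squeeze(1).shape
shape_without_dim = single.squeeze().shape
```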
selected_coreset_idxs: list[torch.Tensor] = []
for _ in tqdm(range(self.coreset_size), desc="Selecting Coreset Indices."):
self.update_distances(cluster_centers=[idx])
self.update_distances(cluster_centers=idx.unsqueeze(0))
idx = self.get_new_idx()
if idx in selected_idxs:
msg = "New indices should not be in selected indices."
raise ValueError(msg)
self.min_distances[idx] = 0
self.min_distances.scatter_(0, idx.unsqueeze(0).unsqueeze(1), 0.0)
selected_coreset_idxs.append(idx)

return selected_coreset_idxs
return [int(tensor_idx.item()) for tensor_idx in selected_coreset_idxs]
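For readers who want to see the loop in this hunk as a whole, the following is a self-contained sketch of greedy k-center selection that keeps the running index as a tensor and converts to Python integers only once at the end. It is not the PR's implementation: the function name `greedy_coreset_idxs`, its parameters, and the seeding scheme are assumptions made for illustration.

```python
import torch


def greedy_coreset_idxs(features: torch.Tensor, coreset_size: int, seed: int = 0) -> list[int]:
    """Hypothetical stand-alone version of the greedy loop (not the anomalib API)."""
    gen = torch.Generator().manual_seed(seed)
    n_observations = features.shape[0]
    # Random starting point, kept as a 0-d tensor on the same device as the features.
    idx = torch.randint(high=n_observations, size=(1,), generator=gen).squeeze(0).to(features.device)
    min_distances = None
    selected: list[torch.Tensor] = []
    for _ in range(coreset_size):
        center = features[idx]  # shape (D,)
        distances = torch.linalg.norm(features - center, ord=2, dim=1, keepdim=True)
        # Track each point's distance to its nearest center seen so far.
        min_distances = distances if min_distances is None else torch.minimum(min_distances, distances)
        # Pick the point farthest from all centers; idx remains a tensor.
        _, idx = torch.max(min_distances.squeeze(1), dim=0)
        # Zero the chosen entry in place so it is not picked again.
        min_distances.scatter_(0, idx.reshape(1, 1), 0.0)
        selected.append(idx)
    # Single conversion to Python ints after the loop.
    return [int(i.item()) for i in selected]


torch.manual_seed(0)
demo_idxs = greedy_coreset_idxs(torch.randn(100, 16), coreset_size=5)
print(demo_idxs)
```

Deferring the `.item()` calls to the final list comprehension means the loop body itself does not force a device-to-host copy on every iteration, which is the overhead the PR description attributes to the original code.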
The type annotation indicates a list of tensors, but the final return statement converts these to integers. Consider using `list[int]` as the type annotation and converting tensors to integers immediately when appending, or keep tensors throughout and convert only at return.
selected_coreset_idxs: list[torch.Tensor] = []
for _ in tqdm(range(self.coreset_size), desc="Selecting Coreset Indices."):
self.update_distances(cluster_centers=[idx])
self.update_distances(cluster_centers=idx.unsqueeze(0))
The method signature accepts `int | torch.Tensor | None`, but here you're always passing a tensor with an extra dimension. This inconsistency with the parameter handling in `update_distances` could lead to confusion. Consider either updating the method to handle the tensor directly or consistently passing the same type.
Suggested change:
- self.update_distances(cluster_centers=idx.unsqueeze(0))
+ self.update_distances(cluster_centers=idx)
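One possible way to act on this comment is to normalize the argument once at the top of the method, so that callers can pass an `int`, a 0-d tensor, or a one-element tensor interchangeably. The helper below is a hypothetical sketch: the name `normalize_center` and its error message are not part of anomalib.

```python
from typing import Optional, Union

import torch


def normalize_center(
    cluster_centers: Union[int, torch.Tensor, None],
    device: torch.device,
) -> Optional[torch.Tensor]:
    """Return a 0-d int64 index tensor on `device`, or None if no center is given."""
    if cluster_centers is None:
        return None
    # as_tensor accepts both Python ints and existing tensors.
    idx = torch.as_tensor(cluster_centers, dtype=torch.long, device=device)
    if idx.numel() != 1:
        raise ValueError("expected exactly one cluster center index")
    return idx.reshape(())


cpu = torch.device("cpu")
features = torch.randn(10, 4)  # hypothetical feature matrix
from_int = normalize_center(3, cpu)
from_tensor = normalize_center(torch.tensor([3]), cpu)
row = features[from_tensor]    # indexing with a 0-d tensor yields shape (D,)
```

With such a step in place, the extra `unsqueeze(0)` at the call site and the conditional `squeeze(0)` inside the method would both become unnecessary.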
📝 Description
Coreset selection is the most time-consuming part of PatchCore. Profiling revealed two problems.
`urEnqueueKernelLaunch` (20.94% CPU time) and `urEnqueueUSMMemcpy` (5.67% CPU time) indicate significant kernel launch overhead and memory copy operations. To address this, I have tried to reduce the calls to `.item()` and to avoid using indices on the CPU while the features are on a different device.
Results
These are the results on different hardware with different dataset sizes.
Numbers are averaged across 5 runs with different seeds.
Additional info
There is scope for cleaning up the code in `class KCenterGreedy` and making it better suited to PatchCore. For example, `def select_coreset_idxs(self, selected_idxs: list[int] | None = None)` is always called without any arguments when used in PatchCore.
✨ Changes
Select what type of change your PR is:
✅ Checklist
Before you submit your pull request, please make sure you have completed the following steps:
For more information about code review checklists, see the Code Review Checklist.