feat: Compute per-query average for 2D retrieval_normalized_dcg
#3229
base: master
Conversation
Hi @Borda, although this PR focuses on fixing `retrieval_normalized_dcg`, the same 2D-input issue affects the other functional retrieval metrics as well. For example, with Average Precision:

```python
from torchmetrics.functional.retrieval import retrieval_average_precision
import torch

# Query 1
p1 = retrieval_average_precision(torch.tensor([0.1, 0.2, 0.3]), torch.tensor([0, 1, 0]))
print(p1)  # tensor(0.5000)

# Query 2
p2 = retrieval_average_precision(torch.tensor([0.8, 0.1, 0.05]), torch.tensor([1, 0, 0]))
print(p2)  # tensor(1.0000)

print("Mean per-query Average Precision:", (p1 + p2) / 2)
# tensor(0.7500)

# Batched input (2D)
p_batch = retrieval_average_precision(
    torch.tensor([[0.1, 0.2, 0.3], [0.8, 0.1, 0.05]]),
    torch.tensor([[0, 1, 0], [1, 0, 0]]),
)
print("Batch Average Precision:", p_batch)
# tensor(0.8333)  <-- not the mean per-query value (same as the NDCG example)
```

Here, the batched input is incorrectly treated as one large query with `num_queries * num_documents` documents, which inflates the reported metric compared to the mean per-query value. To keep the scope manageable, my current plan is to limit this PR to NDCG only, adding tests that demonstrate the correct per-query behavior, so it is easier to review.
Do you think this level of granularity makes sense, or would you prefer the broader retrieval-wide fix to be included directly in this PR?
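For reference, the inflation can be reproduced without torchmetrics at all. The following is a minimal pure-Python sketch (the `average_precision` helper here is an illustrative stand-in, not the library implementation) computing AP per query versus on the flattened ranking:

```python
# Illustrative stand-in for average precision (not the torchmetrics code).
def average_precision(preds, target):
    # Sort documents by descending score, then average the precision
    # at each rank where a relevant document appears.
    order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    hits, precisions = 0, []
    for rank, i in enumerate(order, start=1):
        if target[i]:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions)

q1 = average_precision([0.1, 0.2, 0.3], [0, 1, 0])    # 0.5
q2 = average_precision([0.8, 0.1, 0.05], [1, 0, 0])   # 1.0
per_query_mean = (q1 + q2) / 2                        # 0.75

# Concatenating both queries into one 6-document ranking inflates the score.
flattened = average_precision(
    [0.1, 0.2, 0.3, 0.8, 0.1, 0.05],
    [0, 1, 0, 1, 0, 0],
)  # 5/6, roughly 0.8333
```

This reproduces the 0.75 vs ≈0.833 discrepancy quoted in the torchmetrics example above.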
Codecov Report

```diff
@@           Coverage Diff           @@
##           master   #3229   +/-   ##
=======================================
- Coverage      69%     68%     -0%
=======================================
  Files         364     349     -15
  Lines       20096   19923    -173
=======================================
- Hits        13790   13605    -185
- Misses       6306    6318     +12
```
Pull Request Overview
This PR fixes the NDCG computation for 2D tensor inputs to correctly compute per-query averages instead of treating all queries as a single ranking problem. Previously, 2D inputs were flattened and computed as one large ranking, but now each query is processed separately and averaged.
- Updated `retrieval_normalized_dcg` to handle 2D tensors as batches of samples with per-query computation
- Added `empty_target_action` parameter to handle queries with no positive labels
- Enhanced test infrastructure to support 2D tensor testing and the new parameter
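As a rough sketch of how the new parameter could behave — this is an illustrative stand-in, assuming the same `"neg"`/`"pos"`/`"skip"`/`"error"` options used by the class-based torchmetrics retrieval metrics, not the actual PR code:

```python
import math

def ndcg_per_query(preds_2d, target_2d, empty_target_action="neg"):
    """Illustrative per-query NDCG with handling for queries without positives."""
    scores = []
    for preds, target in zip(preds_2d, target_2d):
        if not any(target):
            # Query with no relevant documents: NDCG is undefined (IDCG == 0).
            if empty_target_action == "error":
                raise ValueError("a query has no positive target labels")
            if empty_target_action == "skip":
                continue
            scores.append(0.0 if empty_target_action == "neg" else 1.0)
            continue
        order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
        dcg = sum(target[i] / math.log2(r + 1) for r, i in enumerate(order, start=1))
        ideal = sorted(target, reverse=True)
        idcg = sum(rel / math.log2(r + 1) for r, rel in enumerate(ideal, start=1))
        scores.append(dcg / idcg)
    return sum(scores) / len(scores) if scores else 0.0

# Second query has no positives; "neg" counts it as 0, "skip" ignores it.
print(ndcg_per_query([[0.8, 0.1], [0.3, 0.7]], [[1, 0], [0, 0]], "neg"))   # 0.5
print(ndcg_per_query([[0.8, 0.1], [0.3, 0.7]], [[1, 0], [0, 0]], "skip"))  # 1.0
```

The default choice matters: `"neg"` penalizes empty queries, while `"skip"` silently changes the effective batch size of the average.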
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| src/torchmetrics/functional/retrieval/ndcg.py | Core logic changes for 2D tensor handling and empty target handling |
| tests/unittests/retrieval/_inputs.py | Added 2D test input data structure |
| tests/unittests/retrieval/helpers.py | Updated test helpers to support 2D inputs and new metric parameter |
| tests/unittests/retrieval/test_ndcg.py | Added test coverage for the new empty_target_action parameter |
```python
original_shape = preds.shape
preds, target = _check_retrieval_functional_inputs(preds, target, allow_non_binary_target=True)

top_k = preds.shape[-1] if top_k is None else top_k

# reshape back if input was 2D
if len(original_shape) == 2:
    preds = preds.view(original_shape)
    target = target.view(original_shape)
else:
    preds = preds.unsqueeze(0)
```
Line 143 should use `target.unsqueeze(0)` instead of `target.view(original_shape)`. The else branch handles 1D inputs, which need to be unsqueezed to match the preds tensor on line 142.
What does this PR do?
Fixes #3216
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Motivation
Currently, when passing a 2D tensor (`[num_queries, num_documents]`) to `retrieval_normalized_dcg`, the function flattens both preds and target and computes DCG/IDCG on the concatenated list. This treats all queries as a single large ranking problem.
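To make the flattening effect concrete, here is a minimal pure-Python NDCG (an illustrative stand-in, not the library code) applied to two small queries:

```python
import math

# Illustrative stand-in for NDCG (not the torchmetrics implementation).
def ndcg(preds, target):
    # Rank documents by descending score, then normalize DCG by the ideal DCG.
    order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    dcg = sum(target[i] / math.log2(rank + 1) for rank, i in enumerate(order, start=1))
    ideal = sorted(target, reverse=True)
    idcg = sum(rel / math.log2(rank + 1) for rank, rel in enumerate(ideal, start=1))
    return dcg / idcg

# Per-query mean: each query is normalized by its own ideal ranking.
per_query = (ndcg([0.1, 0.2, 0.3], [0, 1, 0]) + ndcg([0.8, 0.1, 0.05], [1, 0, 0])) / 2
# roughly 0.8155

# Flattened: both queries concatenated into one 6-document ranking.
flattened = ndcg([0.1, 0.2, 0.3, 0.8, 0.1, 0.05], [0, 1, 0, 1, 0, 0])
# roughly 0.9197, inflated relative to the per-query mean
```

The flattened score is higher because the well-ranked query's relevant document dominates the top of the combined ranking.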
Changes

- Fixes to `retrieval_normalized_dcg`: 2D inputs are now computed per query and averaged.
- Test additions and modifications: updated the reference `ndcg_score` to align conditions with our implementation.

For the reviewer
Feedback on the naming and default behavior of the new `empty_target_action` argument would be greatly appreciated. Please also let us know if any test coverage is missing.

Did you have fun?
Make sure you had fun coding 🙃
📚 Documentation preview 📚: https://torchmetrics--3229.org.readthedocs.build/en/3229/