Skip to content

Commit 12125c0

Browse files
dice score add warnings (#3041)
* dice score add warnings * Apply suggestions from code review --------- Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent e91ced8 commit 12125c0

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/torchmetrics/functional/segmentation/dice.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing_extensions import Literal
1919

2020
from torchmetrics.functional.segmentation.utils import _ignore_background
21+
from torchmetrics.utilities import rank_zero_warn
2122
from torchmetrics.utilities.checks import _check_same_shape
2223
from torchmetrics.utilities.compute import _safe_divide
2324

@@ -154,6 +155,13 @@ def dice_score(
154155
[0.1978, 0.2804, 0.1714, 0.1915, 0.2783]])
155156
156157
"""
158+
if average == "micro":
159+
rank_zero_warn(
160+
"dice_score metric currently defaults to `average=micro`, but will change to"
161+
"`average=macro` in the v1.9 release."
162+
" If you've explicitly set this parameter, you can ignore this warning.",
163+
UserWarning,
164+
)
157165
_dice_score_validate_args(num_classes, include_background, average, input_format)
158166
numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format)
159167
return _dice_score_compute(numerator, denominator, average, support=support)

src/torchmetrics/segmentation/dice.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_dice_score_validate_args,
2424
)
2525
from torchmetrics.metric import Metric
26+
from torchmetrics.utilities import rank_zero_warn
2627
from torchmetrics.utilities.data import dim_zero_cat
2728
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
2829
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
@@ -116,6 +117,13 @@ def __init__(
116117
**kwargs: Any,
117118
) -> None:
118119
super().__init__(**kwargs)
120+
if average == "micro":
121+
rank_zero_warn(
122+
"DiceScore metric currently defaults to `average=micro`, but will change to"
123+
"`average=macro` in the v1.9 release."
124+
" If you've explicitly set this parameter, you can ignore this warning.",
125+
UserWarning,
126+
)
119127
_dice_score_validate_args(num_classes, include_background, average, input_format, zero_division)
120128
self.num_classes = num_classes
121129
self.include_background = include_background

0 commit comments

Comments
 (0)