diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 029f5b25dd1..715969cefee 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -18,6 +18,7 @@ from typing_extensions import Literal from torchmetrics.functional.segmentation.utils import _ignore_background +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.compute import _safe_divide @@ -154,6 +155,13 @@ def dice_score( [0.1978, 0.2804, 0.1714, 0.1915, 0.2783]]) """ + if average == "micro": + rank_zero_warn( + "dice_score metric currently defaults to `average=micro`, but will change to" + "`average=macro` in the v1.9 release." + " If you've explicitly set this parameter, you can ignore this warning.", + UserWarning, + ) _dice_score_validate_args(num_classes, include_background, average, input_format) numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format) return _dice_score_compute(numerator, denominator, average, support=support) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index 54f2eedf0cf..174812e6a1f 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -23,6 +23,7 @@ _dice_score_validate_args, ) from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -116,6 +117,13 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) + if average == "micro": + rank_zero_warn( + "DiceScore metric currently defaults to `average=micro`, but will change to" + "`average=macro` in the v1.9 release." + " If you've explicitly set this parameter, you can ignore this warning.", + UserWarning, + ) _dice_score_validate_args(num_classes, include_background, average, input_format, zero_division) self.num_classes = num_classes self.include_background = include_background