Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/torchmetrics/functional/segmentation/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing_extensions import Literal

from torchmetrics.functional.segmentation.utils import _ignore_background
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_divide

Expand Down Expand Up @@ -154,6 +155,13 @@ def dice_score(
[0.1978, 0.2804, 0.1714, 0.1915, 0.2783]])

"""
if average == "micro":
rank_zero_warn(
"dice_score metric currently defaults to `average=micro`, but will change to"
"`average=macro` in the v1.9 release."
" If you've explicitly set this parameter, you can ignore this warning.",
UserWarning,
)
_dice_score_validate_args(num_classes, include_background, average, input_format)
numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format)
return _dice_score_compute(numerator, denominator, average, support=support)
8 changes: 8 additions & 0 deletions src/torchmetrics/segmentation/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_dice_score_validate_args,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
Expand Down Expand Up @@ -116,6 +117,13 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if average == "micro":
rank_zero_warn(
"DiceScore metric currently defaults to `average=micro`, but will change to"
"`average=macro` in the v1.9 release."
" If you've explicitly set this parameter, you can ignore this warning.",
UserWarning,
)
_dice_score_validate_args(num_classes, include_background, average, input_format, zero_division)
self.num_classes = num_classes
self.include_background = include_background
Expand Down
Loading