Skip to content

Commit 63a11fc

Browse files
committed
fix: fix confusion matrix float32 problem
1 parent 3595cda commit 63a11fc

File tree

2 files changed

+37
-16
lines changed

2 files changed

+37
-16
lines changed

keras/src/metrics/iou_metrics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
name="total_confusion_matrix",
7070
shape=(num_classes, num_classes),
7171
initializer=initializers.Zeros(),
72+
dtype="float64",
7273
)
7374

7475
def update_state(self, y_true, y_pred, sample_weight=None):
@@ -131,7 +132,7 @@ def update_state(self, y_true, y_pred, sample_weight=None):
131132
y_pred,
132133
self.num_classes,
133134
weights=sample_weight,
134-
dtype="float32",
135+
dtype="float64",
135136
)
136137

137138
return self.total_cm.assign(self.total_cm + current_cm)

keras/src/metrics/iou_metrics_test.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import numpy as np
22
import pytest
33

4+
from keras.src import backend
45
from keras.src import layers
56
from keras.src import models
67
from keras.src import testing
7-
from keras.src import backend
88
from keras.src.metrics import iou_metrics as metrics
99

1010

@@ -362,7 +362,8 @@ def confusion_matrix(self, y_true, y_pred, num_classes):
362362
- num_classes: int, number of classes.
363363
364364
Returns:
365-
- conf_matrix: np.ndarray, confusion matrix of shape (num_classes, num_classes).
365+
- conf_matrix: np.ndarray, confusion matrix of shape (num_classes,
366+
num_classes).
366367
"""
367368
y_true = np.asarray(y_true)
368369
y_pred = np.asarray(y_pred)
@@ -374,7 +375,8 @@ def confusion_matrix(self, y_true, y_pred, num_classes):
374375
conf_matrix = conf_matrix.reshape((num_classes, num_classes))
375376
return conf_matrix
376377

377-
def _get_big_chunk(self):
378+
@staticmethod
379+
def _get_big_chunk():
378380
np.random.seed(14)
379381
all_y_true = np.random.choice([0, 1, 2], size=(600, 530, 530))
380382
# Generate random probabilities for each channel
@@ -389,28 +391,46 @@ def _get_big_chunk(self):
389391
tmp_pred = np.reshape(all_y_pred_arg, -1)
390392
return all_y_true, all_y_pred_arg, mean_iou_metric, tmp_true, tmp_pred
391393

392-
@pytest.mark.skipif(backend.backend() not in ["tensorflow", "torch"], reason="tensorflow/torch only test.")
394+
@pytest.mark.skipif(
395+
backend.backend() not in ["tensorflow", "torch"],
396+
reason="tensorflow/torch only test.",
397+
)
393398
def test_big_chunk_tf_torch(self):
394-
all_y_true, all_y_pred_arg, mean_iou_metric_all, tmp_true, tmp_pred = self._get_big_chunk()
399+
all_y_true, all_y_pred_arg, mean_iou_metric_all, tmp_true, tmp_pred = (
400+
self._get_big_chunk()
401+
)
395402
conf_matrix_from_keras = mean_iou_metric_all.total_cm.numpy()
396403
# Validate confusion matrices and results
397404
conf_matrix_manual = self.confusion_matrix(tmp_true, tmp_pred, 3)
398-
self.assertTrue(np.array_equal(conf_matrix_from_keras, conf_matrix_manual),
399-
msg='Confusion matrices do not match!')
405+
self.assertTrue(
406+
np.array_equal(conf_matrix_from_keras, conf_matrix_manual),
407+
msg="Confusion matrices do not match!",
408+
)
400409

401410
@pytest.mark.skipif(backend.backend() != "jax", reason="jax only test.")
402411
def test_big_chunk_jax(self):
403-
import jax.numpy as jnp
404-
all_y_true, all_y_pred_arg, mean_iou_metric_all, tmp_true, tmp_pred = self._get_big_chunk()
405-
conf_matrix_from_keras = jnp.array(mean_iou_metric_all.total_cm)
406-
# Validate confusion matrices and results
407-
conf_matrix_manual = self.confusion_matrix(tmp_true, tmp_pred, 3)
408-
self.assertTrue(np.array_equal(conf_matrix_from_keras, conf_matrix_manual),
409-
msg='Confusion matrices do not match!')
410-
412+
from jax import config
411413

414+
config.update("jax_enable_x64", False)
412415

416+
all_y_true, all_y_pred_arg, mean_iou_metric_all, tmp_true, tmp_pred = (
417+
self._get_big_chunk()
418+
)
419+
conf_matrix_from_keras = np.array(mean_iou_metric_all.total_cm)
420+
# Validate confusion matrices and results
421+
conf_matrix_manual = self.confusion_matrix(tmp_true, tmp_pred, 3)
422+
self.assertFalse(
423+
np.array_equal(conf_matrix_from_keras, conf_matrix_manual),
424+
msg="Confusion matrices do not match!",
425+
)
413426

427+
config.update("jax_enable_x64", True)
428+
_, _, mean_iou_metric_all, _, _ = self._get_big_chunk()
429+
conf_matrix_from_keras = np.array(mean_iou_metric_all.total_cm)
430+
self.assertTrue(
431+
np.array_equal(conf_matrix_from_keras, conf_matrix_manual),
432+
msg="Confusion matrices do not match!",
433+
)
414434

415435

416436
class OneHotIoUTest(testing.TestCase):

0 commit comments

Comments
 (0)