1
1
import numpy as np
2
2
import pytest
3
3
4
+ from keras .src import backend
4
5
from keras .src import layers
5
6
from keras .src import models
6
7
from keras .src import testing
7
- from keras .src import backend
8
8
from keras .src .metrics import iou_metrics as metrics
9
9
10
10
@@ -362,7 +362,8 @@ def confusion_matrix(self, y_true, y_pred, num_classes):
362
362
- num_classes: int, number of classes.
363
363
364
364
Returns:
365
- - conf_matrix: np.ndarray, confusion matrix of shape (num_classes, num_classes).
365
+ - conf_matrix: np.ndarray, confusion matrix of shape (num_classes,
366
+ num_classes).
366
367
"""
367
368
y_true = np .asarray (y_true )
368
369
y_pred = np .asarray (y_pred )
@@ -374,7 +375,8 @@ def confusion_matrix(self, y_true, y_pred, num_classes):
374
375
conf_matrix = conf_matrix .reshape ((num_classes , num_classes ))
375
376
return conf_matrix
376
377
377
- def _get_big_chunk (self ):
378
+ @staticmethod
379
+ def _get_big_chunk ():
378
380
np .random .seed (14 )
379
381
all_y_true = np .random .choice ([0 , 1 , 2 ], size = (600 , 530 , 530 ))
380
382
# Generate random probabilities for each channel
@@ -389,28 +391,46 @@ def _get_big_chunk(self):
389
391
tmp_pred = np .reshape (all_y_pred_arg , - 1 )
390
392
return all_y_true , all_y_pred_arg , mean_iou_metric , tmp_true , tmp_pred
391
393
392
- @pytest .mark .skipif (backend .backend () not in ["tensorflow" , "torch" ], reason = "tensorflow/torch only test." )
394
+ @pytest .mark .skipif (
395
+ backend .backend () not in ["tensorflow" , "torch" ],
396
+ reason = "tensorflow/torch only test." ,
397
+ )
393
398
def test_big_chunk_tf_torch (self ):
394
- all_y_true , all_y_pred_arg , mean_iou_metric_all , tmp_true , tmp_pred = self ._get_big_chunk ()
399
+ all_y_true , all_y_pred_arg , mean_iou_metric_all , tmp_true , tmp_pred = (
400
+ self ._get_big_chunk ()
401
+ )
395
402
conf_matrix_from_keras = mean_iou_metric_all .total_cm .numpy ()
396
403
# Validate confusion matrices and results
397
404
conf_matrix_manual = self .confusion_matrix (tmp_true , tmp_pred , 3 )
398
- self .assertTrue (np .array_equal (conf_matrix_from_keras , conf_matrix_manual ),
399
- msg = 'Confusion matrices do not match!' )
405
+ self .assertTrue (
406
+ np .array_equal (conf_matrix_from_keras , conf_matrix_manual ),
407
+ msg = "Confusion matrices do not match!" ,
408
+ )
400
409
401
410
@pytest .mark .skipif (backend .backend () != "jax" , reason = "jax only test." )
402
411
def test_big_chunk_jax (self ):
403
- import jax .numpy as jnp
404
- all_y_true , all_y_pred_arg , mean_iou_metric_all , tmp_true , tmp_pred = self ._get_big_chunk ()
405
- conf_matrix_from_keras = jnp .array (mean_iou_metric_all .total_cm )
406
- # Validate confusion matrices and results
407
- conf_matrix_manual = self .confusion_matrix (tmp_true , tmp_pred , 3 )
408
- self .assertTrue (np .array_equal (conf_matrix_from_keras , conf_matrix_manual ),
409
- msg = 'Confusion matrices do not match!' )
410
-
412
+ from jax import config
411
413
414
+ config .update ("jax_enable_x64" , False )
412
415
416
+ all_y_true , all_y_pred_arg , mean_iou_metric_all , tmp_true , tmp_pred = (
417
+ self ._get_big_chunk ()
418
+ )
419
+ conf_matrix_from_keras = np .array (mean_iou_metric_all .total_cm )
420
+ # Validate confusion matrices and results
421
+ conf_matrix_manual = self .confusion_matrix (tmp_true , tmp_pred , 3 )
422
+ self .assertFalse (
423
+ np .array_equal (conf_matrix_from_keras , conf_matrix_manual ),
424
+ msg = "Confusion matrices do not match!" ,
425
+ )
413
426
427
+ config .update ("jax_enable_x64" , True )
428
+ _ , _ , mean_iou_metric_all , _ , _ = self ._get_big_chunk ()
429
+ conf_matrix_from_keras = np .array (mean_iou_metric_all .total_cm )
430
+ self .assertTrue (
431
+ np .array_equal (conf_matrix_from_keras , conf_matrix_manual ),
432
+ msg = "Confusion matrices do not match!" ,
433
+ )
414
434
415
435
416
436
class OneHotIoUTest (testing .TestCase ):
0 commit comments