@@ -1468,3 +1468,115 @@ def test_istft_invalid_window_shape_2D_inputs(self):
1468
1468
fft_length ,
1469
1469
window = incorrect_window ,
1470
1470
)
1471
+
1472
+
1473
+ class HistogramTest (testing .TestCase ):
1474
+ def test_histogram_default_args (self ):
1475
+ hist_op = kmath .histogram
1476
+ input_tensor = np .random .rand (8 )
1477
+
1478
+ # Expected output
1479
+ expected_counts , expected_edges = np .histogram (input_tensor )
1480
+
1481
+ counts , edges = hist_op (input_tensor )
1482
+
1483
+ self .assertEqual (counts .shape , expected_counts .shape )
1484
+ self .assertAllClose (counts , expected_counts )
1485
+ self .assertEqual (edges .shape , expected_edges .shape )
1486
+ self .assertAllClose (edges , expected_edges )
1487
+
1488
+ def test_histogram_custom_bins (self ):
1489
+ hist_op = kmath .histogram
1490
+ input_tensor = np .random .rand (8 )
1491
+ bins = 5
1492
+
1493
+ # Expected output
1494
+ expected_counts , expected_edges = np .histogram (input_tensor , bins = bins )
1495
+
1496
+ counts , edges = hist_op (input_tensor , bins = bins )
1497
+
1498
+ self .assertEqual (counts .shape , expected_counts .shape )
1499
+ self .assertAllClose (counts , expected_counts )
1500
+ self .assertEqual (edges .shape , expected_edges .shape )
1501
+ self .assertAllClose (edges , expected_edges )
1502
+
1503
+ def test_histogram_custom_range (self ):
1504
+ hist_op = kmath .histogram
1505
+ input_tensor = np .random .rand (10 )
1506
+ range_specified = (2 , 8 )
1507
+
1508
+ # Expected output
1509
+ expected_counts , expected_edges = np .histogram (
1510
+ input_tensor , range = range_specified
1511
+ )
1512
+
1513
+ counts , edges = hist_op (input_tensor , range = range_specified )
1514
+
1515
+ self .assertEqual (counts .shape , expected_counts .shape )
1516
+ self .assertAllClose (counts , expected_counts )
1517
+ self .assertEqual (edges .shape , expected_edges .shape )
1518
+ self .assertAllClose (edges , expected_edges )
1519
+
1520
+ def test_histogram_symbolic_input (self ):
1521
+ hist_op = kmath .histogram
1522
+ input_tensor = KerasTensor (shape = (None ,), dtype = "float32" )
1523
+
1524
+ counts , edges = hist_op (input_tensor )
1525
+
1526
+ self .assertEqual (counts .shape , (10 ,))
1527
+ self .assertEqual (edges .shape , (11 ,))
1528
+
1529
+ def test_histogram_non_integer_bins_raises_error (self ):
1530
+ hist_op = kmath .histogram
1531
+ input_tensor = np .random .rand (8 )
1532
+
1533
+ with self .assertRaisesRegex (
1534
+ ValueError , "`bins` should be a non-negative integer"
1535
+ ):
1536
+ hist_op (input_tensor , bins = - 5 )
1537
+
1538
+ def test_histogram_range_validation (self ):
1539
+ hist_op = kmath .histogram
1540
+ input_tensor = np .random .rand (8 )
1541
+
1542
+ with self .assertRaisesRegex (
1543
+ ValueError , "range must be a tuple of two elements"
1544
+ ):
1545
+ hist_op (input_tensor , range = (1 ,))
1546
+
1547
+ with self .assertRaisesRegex (
1548
+ ValueError ,
1549
+ "The second element of range must be greater than the first" ,
1550
+ ):
1551
+ hist_op (input_tensor , range = (5 , 1 ))
1552
+
1553
+ def test_histogram_large_values (self ):
1554
+ hist_op = kmath .histogram
1555
+ input_tensor = np .array ([1e10 , 2e10 , 3e10 , 4e10 , 5e10 ])
1556
+
1557
+ counts , edges = hist_op (input_tensor , bins = 5 )
1558
+
1559
+ expected_counts , expected_edges = np .histogram (input_tensor , bins = 5 )
1560
+
1561
+ self .assertAllClose (counts , expected_counts )
1562
+ self .assertAllClose (edges , expected_edges )
1563
+
1564
+ def test_histogram_float_input (self ):
1565
+ hist_op = kmath .histogram
1566
+ input_tensor = np .random .rand (8 )
1567
+
1568
+ counts , edges = hist_op (input_tensor , bins = 5 )
1569
+
1570
+ expected_counts , expected_edges = np .histogram (input_tensor , bins = 5 )
1571
+
1572
+ self .assertAllClose (counts , expected_counts )
1573
+ self .assertAllClose (edges , expected_edges )
1574
+
1575
+ def test_histogram_high_dimensional_input (self ):
1576
+ hist_op = kmath .histogram
1577
+ input_tensor = np .random .rand (3 , 4 , 5 )
1578
+
1579
+ with self .assertRaisesRegex (
1580
+ ValueError , "Input tensor must be 1-dimensional"
1581
+ ):
1582
+ hist_op (input_tensor )
0 commit comments