Skip to content

Commit 08910e2

Browse files
authored
[Keras Ops] Add Histogram Operation (#20316)
* add histogram operation to keras.ops * update docstrings * extract values from torch op
1 parent 084b7e1 commit 08910e2

File tree

8 files changed

+241
-0
lines changed

8 files changed

+241
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from keras.src.ops.math import extract_sequences
4848
from keras.src.ops.math import fft
4949
from keras.src.ops.math import fft2
50+
from keras.src.ops.math import histogram
5051
from keras.src.ops.math import in_top_k
5152
from keras.src.ops.math import irfft
5253
from keras.src.ops.math import istft

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from keras.src.ops.math import extract_sequences
4848
from keras.src.ops.math import fft
4949
from keras.src.ops.math import fft2
50+
from keras.src.ops.math import histogram
5051
from keras.src.ops.math import in_top_k
5152
from keras.src.ops.math import irfft
5253
from keras.src.ops.math import istft

keras/src/backend/jax/math.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,3 +294,7 @@ def logdet(x):
294294
# `np.log(np.linalg.det(x))`. See
295295
# https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html
296296
return slogdet(x)[1]
297+
298+
299+
def histogram(x, bins, range):
    """Compute a histogram of `x` by delegating to `jnp.histogram`.

    Args:
        x: Input data, forwarded unchanged to `jnp.histogram`.
        bins: Forwarded as the `bins` keyword argument.
        range: Forwarded as the `range` keyword argument.

    Returns:
        A `(counts, bin_edges)` pair as produced by `jnp.histogram`.
    """
    counts, bin_edges = jnp.histogram(x, bins=bins, range=range)
    return counts, bin_edges

keras/src/backend/numpy/math.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,7 @@ def logdet(x):
316316
# In NumPy slogdet is more stable than `np.log(np.linalg.det(x))`. See
317317
# https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html
318318
return slogdet(x)[1]
319+
320+
321+
def histogram(x, bins, range):
    """Compute a histogram of `x` by delegating to `np.histogram`.

    Args:
        x: Input data, forwarded unchanged to `np.histogram`.
        bins: Forwarded as the `bins` keyword argument.
        range: Forwarded as the `range` keyword argument.

    Returns:
        A `(counts, bin_edges)` pair as produced by `np.histogram`.
    """
    counts, bin_edges = np.histogram(x, bins=bins, range=range)
    return counts, bin_edges

keras/src/backend/tensorflow/math.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,31 @@ def norm(x, ord=None, axis=None, keepdims=False):
370370
def logdet(x):
371371
x = convert_to_tensor(x)
372372
return tf.linalg.logdet(x)
373+
374+
375+
def histogram(x, bins, range):
    """Computes a histogram of the data tensor `x` using TensorFlow.

    The histogram is built manually (mask, bucketize, bincount) because the
    `tf.histogram_fixed_width()` and `tf.histogram_fixed_width_bins()`
    methods yielded slight numerical differences on some edge cases.

    Args:
        x: Input data. It must already expose a `.dtype` attribute, since
            that attribute is read before the conversion to a tensor.
        bins: Number of bins; `bins + 1` edges are generated.
        range: Either `None` (the min and max of `x` are used) or a
            `(min, max)` pair that is unpacked directly.

    Returns:
        A `(bin_counts, bin_edges)` pair. `bin_counts` is produced by
        `tf.math.bincount` with `dtype=x.dtype`; `bin_edges` is the output
        of `tf.linspace`.
    """

    x = tf.convert_to_tensor(x, dtype=x.dtype)

    # Handle the range argument
    if range is None:
        min_val = tf.reduce_min(x)
        max_val = tf.reduce_max(x)
    else:
        min_val, max_val = range

    # Drop every element that lies outside the closed interval
    # [min_val, max_val] before binning.
    x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val))
    bin_edges = tf.linspace(min_val, max_val, bins + 1)
    # NOTE(review): `.numpy()` materializes the edges as a Python list; this
    # presumably only works in eager execution (not inside `tf.function`) --
    # confirm before relying on this op in graph mode.
    bin_edges_list = bin_edges.numpy().tolist()
    # Only the interior edges are passed as boundaries (first and last edge
    # are sliced off), presumably so the resulting bucket indices stay within
    # [0, bins - 1] for the already range-filtered values -- verify against
    # the `Bucketize` documentation.
    bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1])

    # `minlength == maxlength == bins` requests a counts vector with exactly
    # one entry per bin.
    bin_counts = tf.math.bincount(
        bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype
    )

    return bin_counts, bin_edges

keras/src/backend/torch/math.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,8 @@ def norm(x, ord=None, axis=None, keepdims=False):
419419
def logdet(x):
420420
x = convert_to_tensor(x)
421421
return torch.logdet(x)
422+
423+
424+
def histogram(x, bins, range):
    """Compute a histogram of `x` by delegating to `torch.histogram`.

    Args:
        x: Input tensor, forwarded unchanged to `torch.histogram`.
        bins: Forwarded as the `bins` keyword argument.
        range: Forwarded as the `range` keyword argument.

    Returns:
        A plain `(hist, bin_edges)` tuple extracted from the structured
        result returned by `torch.histogram`.
    """
    # The torch result is a named tuple of (hist, bin_edges); unpack it so
    # callers receive an ordinary 2-tuple.
    counts, bin_edges = torch.histogram(x, bins=bins, range=range)
    return counts, bin_edges

keras/src/ops/math.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,3 +971,89 @@ def logdet(x):
971971
if any_symbolic_tensors((x,)):
972972
return Logdet().symbolic_call(x)
973973
return backend.math.logdet(x)
974+
975+
976+
class Histogram(Operation):
    """Operation that computes a histogram of a tensor with at most 1 axis.

    Args:
        bins: Integer number of histogram bins. Defaults to 10.
        range: Optional `(lower, upper)` tuple bounding the bins. When
            `None`, the choice of bounds is left to the backend.

    Raises:
        TypeError: If `bins` is not an `int`.
        ValueError: If `bins` is negative, if `range` is not a tuple of
            exactly two elements, or if `range[1] < range[0]`.
    """

    def __init__(self, bins=10, range=None):
        super().__init__()

        if not isinstance(bins, int):
            raise TypeError("bins must be of type `int`")
        if bins < 0:
            raise ValueError("`bins` should be a non-negative integer")

        # Compare against `None` instead of relying on truthiness so that
        # falsy-but-invalid values (`()`, `0`, `False`) are rejected here
        # rather than silently forwarded to the backend.
        if range is not None:
            # Check the type first: calling `len()` on an unsized object
            # would raise an unrelated `TypeError`. Require exactly two
            # elements, as the error message states.
            if not isinstance(range, tuple) or len(range) != 2:
                raise ValueError("range must be a tuple of two elements")

            if range[1] < range[0]:
                raise ValueError(
                    "The second element of range must be greater than the first"
                )

        self.bins = bins
        self.range = range

    def call(self, x):
        """Compute the histogram of `x` with the configured bins and range.

        Raises:
            ValueError: If `x` has more than one dimension.
        """
        x = backend.convert_to_tensor(x)
        if len(x.shape) > 1:
            raise ValueError("Input tensor must be 1-dimensional")
        return backend.math.histogram(x, bins=self.bins, range=self.range)

    def compute_output_spec(self, x):
        """Return symbolic specs for the `(counts, bin_edges)` outputs.

        Counts have shape `(bins,)` and edges have shape `(bins + 1,)`;
        both specs are declared with the dtype of `x`.
        """
        return (
            KerasTensor(shape=(self.bins,), dtype=x.dtype),
            KerasTensor(shape=(self.bins + 1,), dtype=x.dtype),
        )
1008+
1009+
1010+
@keras_export("keras.ops.histogram")
def histogram(x, bins=10, range=None):
    """Computes a histogram of the data tensor `x`.

    Args:
        x: Input tensor. Must have at most one dimension.
        bins: An integer representing the number of histogram bins.
            Defaults to 10.
        range: A tuple representing the lower and upper range of the bins.
            If not specified, it will use the min and max of `x`.

    Returns:
        A tuple containing:
        - A tensor representing the counts of elements in each bin.
        - A tensor representing the bin edges.

    Raises:
        TypeError: If `bins` is not an `int`.
        ValueError: If `bins` is negative, if `range` is not a tuple of
            exactly two elements, if `range[1] < range[0]`, or if `x` has
            more than one dimension.

    Example:

    ```
    >>> input_tensor = np.random.rand(8)
    >>> keras.ops.histogram(input_tensor)
    (array([1, 1, 1, 0, 0, 1, 2, 1, 0, 1], dtype=int32),
    array([0.0189519 , 0.10294958, 0.18694726, 0.27094494, 0.35494262,
        0.43894029, 0.52293797, 0.60693565, 0.69093333, 0.77493101,
        0.85892869]))
    ```

    """

    if not isinstance(bins, int):
        raise TypeError("bins must be of type `int`")
    if bins < 0:
        raise ValueError("`bins` should be a non-negative integer")

    # Compare against `None` instead of relying on truthiness so that
    # falsy-but-invalid values (`()`, `0`, `False`) are rejected here rather
    # than silently forwarded to the backend.
    if range is not None:
        # Check the type first: calling `len()` on an unsized object would
        # raise an unrelated `TypeError`. Require exactly two elements, as
        # the error message states.
        if not isinstance(range, tuple) or len(range) != 2:
            raise ValueError("range must be a tuple of two elements")

        if range[1] < range[0]:
            raise ValueError(
                "The second element of range must be greater than the first"
            )

    if any_symbolic_tensors((x,)):
        return Histogram(bins=bins, range=range).symbolic_call(x)

    x = backend.convert_to_tensor(x)
    if len(x.shape) > 1:
        raise ValueError("Input tensor must be 1-dimensional")
    return backend.math.histogram(x, bins=bins, range=range)

keras/src/ops/math_test.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,3 +1468,115 @@ def test_istft_invalid_window_shape_2D_inputs(self):
14681468
fft_length,
14691469
window=incorrect_window,
14701470
)
1471+
1472+
1473+
class HistogramTest(testing.TestCase):
    """Tests for `kmath.histogram`, using `np.histogram` as the reference."""

    def _assert_matches_numpy(self, input_tensor, **kwargs):
        """Check that `kmath.histogram` agrees with `np.histogram`.

        The same keyword arguments are forwarded to both implementations;
        shapes and values of counts and edges are compared.
        """
        expected_counts, expected_edges = np.histogram(input_tensor, **kwargs)
        counts, edges = kmath.histogram(input_tensor, **kwargs)

        self.assertEqual(counts.shape, expected_counts.shape)
        self.assertAllClose(counts, expected_counts)
        self.assertEqual(edges.shape, expected_edges.shape)
        self.assertAllClose(edges, expected_edges)

    def test_histogram_default_args(self):
        self._assert_matches_numpy(np.random.rand(8))

    def test_histogram_custom_bins(self):
        self._assert_matches_numpy(np.random.rand(8), bins=5)

    def test_histogram_custom_range(self):
        self._assert_matches_numpy(np.random.rand(10), range=(2, 8))

    def test_histogram_symbolic_input(self):
        input_tensor = KerasTensor(shape=(None,), dtype="float32")

        counts, edges = kmath.histogram(input_tensor)

        # Default `bins=10` gives 10 counts and 11 edges.
        self.assertEqual(counts.shape, (10,))
        self.assertEqual(edges.shape, (11,))

    def test_histogram_non_integer_bins_raises_error(self):
        input_tensor = np.random.rand(8)

        # A non-`int` value for `bins` must be rejected with a `TypeError`.
        with self.assertRaisesRegex(TypeError, "bins must be of type `int`"):
            kmath.histogram(input_tensor, bins=2.5)

    def test_histogram_negative_bins_raises_error(self):
        input_tensor = np.random.rand(8)

        with self.assertRaisesRegex(
            ValueError, "`bins` should be a non-negative integer"
        ):
            kmath.histogram(input_tensor, bins=-5)

    def test_histogram_range_validation(self):
        input_tensor = np.random.rand(8)

        with self.assertRaisesRegex(
            ValueError, "range must be a tuple of two elements"
        ):
            kmath.histogram(input_tensor, range=(1,))

        with self.assertRaisesRegex(
            ValueError,
            "The second element of range must be greater than the first",
        ):
            kmath.histogram(input_tensor, range=(5, 1))

    def test_histogram_large_values(self):
        input_tensor = np.array([1e10, 2e10, 3e10, 4e10, 5e10])
        self._assert_matches_numpy(input_tensor, bins=5)

    def test_histogram_float_input(self):
        self._assert_matches_numpy(np.random.rand(8), bins=5)

    def test_histogram_high_dimensional_input(self):
        input_tensor = np.random.rand(3, 4, 5)

        with self.assertRaisesRegex(
            ValueError, "Input tensor must be 1-dimensional"
        ):
            kmath.histogram(input_tensor)

0 commit comments

Comments
 (0)