Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from keras.src.ops.math import extract_sequences
from keras.src.ops.math import fft
from keras.src.ops.math import fft2
from keras.src.ops.math import histogram
from keras.src.ops.math import in_top_k
from keras.src.ops.math import irfft
from keras.src.ops.math import istft
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from keras.src.ops.math import extract_sequences
from keras.src.ops.math import fft
from keras.src.ops.math import fft2
from keras.src.ops.math import histogram
from keras.src.ops.math import in_top_k
from keras.src.ops.math import irfft
from keras.src.ops.math import istft
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,7 @@ def logdet(x):
# `np.log(np.linalg.det(x))`. See
# https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html
return slogdet(x)[1]


def histogram(x, bins, range):
return jnp.histogram(x, bins=bins, range=range)
4 changes: 4 additions & 0 deletions keras/src/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,7 @@ def logdet(x):
# In NumPy slogdet is more stable than `np.log(np.linalg.det(x))`. See
# https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html
return slogdet(x)[1]


def histogram(x, bins, range):
return np.histogram(x, bins=bins, range=range)
28 changes: 28 additions & 0 deletions keras/src/backend/tensorflow/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,31 @@ def norm(x, ord=None, axis=None, keepdims=False):
def logdet(x):
x = convert_to_tensor(x)
return tf.linalg.logdet(x)


def histogram(x, bins, range):
"""
Computes a histogram of the data tensor `x` using TensorFlow.
The `tf.histogram_fixed_width()` and `tf.histogram_fixed_width_bins()`
methods yielded slight numerical differences on some edge cases.
"""

x = tf.convert_to_tensor(x, dtype=x.dtype)

# Handle the range argument
if range is None:
min_val = tf.reduce_min(x)
max_val = tf.reduce_max(x)
else:
min_val, max_val = range

x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val))
bin_edges = tf.linspace(min_val, max_val, bins + 1)
bin_edges_list = bin_edges.numpy().tolist()
bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1])

bin_counts = tf.math.bincount(
bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype
)

return bin_counts, bin_edges
5 changes: 5 additions & 0 deletions keras/src/backend/torch/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,8 @@ def norm(x, ord=None, axis=None, keepdims=False):
def logdet(x):
x = convert_to_tensor(x)
return torch.logdet(x)


def histogram(x, bins, range):
hist_result = torch.histogram(x, bins=bins, range=range)
return hist_result.hist, hist_result.bin_edges
86 changes: 86 additions & 0 deletions keras/src/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,3 +971,89 @@ def logdet(x):
if any_symbolic_tensors((x,)):
return Logdet().symbolic_call(x)
return backend.math.logdet(x)


class Histogram(Operation):
def __init__(self, bins=10, range=None):
super().__init__()

if not isinstance(bins, int):
raise TypeError("bins must be of type `int`")
if bins < 0:
raise ValueError("`bins` should be a non-negative integer")

if range:
if len(range) < 2 or not isinstance(range, tuple):
raise ValueError("range must be a tuple of two elements")

if range[1] < range[0]:
raise ValueError(
"The second element of range must be greater than the first"
)

self.bins = bins
self.range = range

def call(self, x):
x = backend.convert_to_tensor(x)
if len(x.shape) > 1:
raise ValueError("Input tensor must be 1-dimensional")
return backend.math.histogram(x, bins=self.bins, range=self.range)

def compute_output_spec(self, x):
return (
KerasTensor(shape=(self.bins,), dtype=x.dtype),
KerasTensor(shape=(self.bins + 1,), dtype=x.dtype),
)


@keras_export("keras.ops.histogram")
def histogram(x, bins=10, range=None):
"""Computes a histogram of the data tensor `x`.

Args:
x: Input tensor.
bins: An integer representing the number of histogram bins.
Defaults to 10.
range: A tuple representing the lower and upper range of the bins.
If not specified, it will use the min and max of `x`.

Returns:
A tuple containing:
- A tensor representing the counts of elements in each bin.
- A tensor representing the bin edges.

Example:

```
>>> nput_tensor = np.random.rand(8)
>>> keras.ops.histogram(input_tensor)
(array([1, 1, 1, 0, 0, 1, 2, 1, 0, 1], dtype=int32),
array([0.0189519 , 0.10294958, 0.18694726, 0.27094494, 0.35494262,
0.43894029, 0.52293797, 0.60693565, 0.69093333, 0.77493101,
0.85892869]))
```

"""

if not isinstance(bins, int):
raise TypeError("bins must be of type `int`")
if bins < 0:
raise ValueError("`bins` should be a non-negative integer")

if range:
if len(range) < 2 or not isinstance(range, tuple):
raise ValueError("range must be a tuple of two elements")

if range[1] < range[0]:
raise ValueError(
"The second element of range must be greater than the first"
)

if any_symbolic_tensors((x,)):
return Histogram(bins=bins, range=range).symbolic_call(x)

x = backend.convert_to_tensor(x)
if len(x.shape) > 1:
raise ValueError("Input tensor must be 1-dimensional")
return backend.math.histogram(x, bins=bins, range=range)
112 changes: 112 additions & 0 deletions keras/src/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,3 +1468,115 @@ def test_istft_invalid_window_shape_2D_inputs(self):
fft_length,
window=incorrect_window,
)


class HistogramTest(testing.TestCase):
def test_histogram_default_args(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(8)

# Expected output
expected_counts, expected_edges = np.histogram(input_tensor)

counts, edges = hist_op(input_tensor)

self.assertEqual(counts.shape, expected_counts.shape)
self.assertAllClose(counts, expected_counts)
self.assertEqual(edges.shape, expected_edges.shape)
self.assertAllClose(edges, expected_edges)

def test_histogram_custom_bins(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(8)
bins = 5

# Expected output
expected_counts, expected_edges = np.histogram(input_tensor, bins=bins)

counts, edges = hist_op(input_tensor, bins=bins)

self.assertEqual(counts.shape, expected_counts.shape)
self.assertAllClose(counts, expected_counts)
self.assertEqual(edges.shape, expected_edges.shape)
self.assertAllClose(edges, expected_edges)

def test_histogram_custom_range(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(10)
range_specified = (2, 8)

# Expected output
expected_counts, expected_edges = np.histogram(
input_tensor, range=range_specified
)

counts, edges = hist_op(input_tensor, range=range_specified)

self.assertEqual(counts.shape, expected_counts.shape)
self.assertAllClose(counts, expected_counts)
self.assertEqual(edges.shape, expected_edges.shape)
self.assertAllClose(edges, expected_edges)

def test_histogram_symbolic_input(self):
hist_op = kmath.histogram
input_tensor = KerasTensor(shape=(None,), dtype="float32")

counts, edges = hist_op(input_tensor)

self.assertEqual(counts.shape, (10,))
self.assertEqual(edges.shape, (11,))

def test_histogram_non_integer_bins_raises_error(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(8)

with self.assertRaisesRegex(
ValueError, "`bins` should be a non-negative integer"
):
hist_op(input_tensor, bins=-5)

def test_histogram_range_validation(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(8)

with self.assertRaisesRegex(
ValueError, "range must be a tuple of two elements"
):
hist_op(input_tensor, range=(1,))

with self.assertRaisesRegex(
ValueError,
"The second element of range must be greater than the first",
):
hist_op(input_tensor, range=(5, 1))

def test_histogram_large_values(self):
hist_op = kmath.histogram
input_tensor = np.array([1e10, 2e10, 3e10, 4e10, 5e10])

counts, edges = hist_op(input_tensor, bins=5)

expected_counts, expected_edges = np.histogram(input_tensor, bins=5)

self.assertAllClose(counts, expected_counts)
self.assertAllClose(edges, expected_edges)

def test_histogram_float_input(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(8)

counts, edges = hist_op(input_tensor, bins=5)

expected_counts, expected_edges = np.histogram(input_tensor, bins=5)

self.assertAllClose(counts, expected_counts)
self.assertAllClose(edges, expected_edges)

def test_histogram_high_dimensional_input(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(3, 4, 5)

with self.assertRaisesRegex(
ValueError, "Input tensor must be 1-dimensional"
):
hist_op(input_tensor)