Skip to content

Commit 61d85f3

Browse files
Fix istft and add class TestMathErrors in ops/math_test.py (#19594)
* Fix and test math functions for jax backend * run /workspaces/keras/shell/format.sh * refix * fix * fix _get_complex_tensor_from_tuple * fix * refix * Fix istft function to handle inputs with less than 2 dimensions * fix * Fix ValueError in istft function for inputs with less than 2 dimensions
1 parent 4cb5671 commit 61d85f3

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed

keras/src/backend/jax/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,12 @@ def istft(
204204
x = _get_complex_tensor_from_tuple(x)
205205
dtype = jnp.real(x).dtype
206206

207+
if len(x.shape) < 2:
208+
raise ValueError(
209+
f"Input `x` must have at least 2 dimensions. "
210+
f"Received shape: {x.shape}"
211+
)
212+
207213
expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1)
208214
l_pad = (fft_length - sequence_length) // 2
209215
r_pad = fft_length - sequence_length - l_pad

keras/src/ops/linalg_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,15 @@ def test_qr(self):
101101
self.assertEqual(q.shape, qref_shape)
102102
self.assertEqual(r.shape, rref_shape)
103103

104+
def test_qr_invalid_mode(self):
105+
# backend agnostic error message
106+
x = np.array([[1, 2], [3, 4]])
107+
invalid_mode = "invalid_mode"
108+
with self.assertRaisesRegex(
109+
ValueError, "Expected one of {'reduced', 'complete'}."
110+
):
111+
linalg.qr(x, mode=invalid_mode)
112+
104113
def test_solve(self):
105114
a = KerasTensor([None, 20, 20])
106115
b = KerasTensor([None, 20, 5])

keras/src/ops/math_test.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22

3+
import jax.numpy as jnp
34
import numpy as np
45
import pytest
56
import scipy.signal
@@ -1256,3 +1257,90 @@ def test_undefined_fft_length_and_last_dimension(self):
12561257
expected_shape = real_part.shape[:-1] + (None,)
12571258

12581259
self.assertEqual(output_spec.shape, expected_shape)
1260+
1261+
1262+
class TestMathErrors(testing.TestCase):
1263+
1264+
@pytest.mark.skipif(
1265+
backend.backend() != "jax", reason="Testing Jax errors only"
1266+
)
1267+
def test_segment_sum_no_num_segments(self):
1268+
data = jnp.array([1, 2, 3, 4])
1269+
segment_ids = jnp.array([0, 0, 1, 1])
1270+
with self.assertRaisesRegex(
1271+
ValueError,
1272+
"Argument `num_segments` must be set when using the JAX backend.",
1273+
):
1274+
kmath.segment_sum(data, segment_ids)
1275+
1276+
@pytest.mark.skipif(
1277+
backend.backend() != "jax", reason="Testing Jax errors only"
1278+
)
1279+
def test_segment_max_no_num_segments(self):
1280+
data = jnp.array([1, 2, 3, 4])
1281+
segment_ids = jnp.array([0, 0, 1, 1])
1282+
with self.assertRaisesRegex(
1283+
ValueError,
1284+
"Argument `num_segments` must be set when using the JAX backend.",
1285+
):
1286+
kmath.segment_max(data, segment_ids)
1287+
1288+
def test_stft_invalid_input_type(self):
1289+
# backend agnostic error message
1290+
x = np.array([1, 2, 3, 4])
1291+
sequence_length = 2
1292+
sequence_stride = 1
1293+
fft_length = 4
1294+
with self.assertRaisesRegex(TypeError, "`float32` or `float64`"):
1295+
kmath.stft(x, sequence_length, sequence_stride, fft_length)
1296+
1297+
def test_invalid_fft_length(self):
1298+
# backend agnostic error message
1299+
x = np.array([1.0, 2.0, 3.0, 4.0])
1300+
sequence_length = 4
1301+
sequence_stride = 1
1302+
fft_length = 2
1303+
with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"):
1304+
kmath.stft(x, sequence_length, sequence_stride, fft_length)
1305+
1306+
def test_stft_invalid_window(self):
1307+
# backend agnostic error message
1308+
x = np.array([1.0, 2.0, 3.0, 4.0])
1309+
sequence_length = 2
1310+
sequence_stride = 1
1311+
fft_length = 4
1312+
window = "invalid_window"
1313+
with self.assertRaisesRegex(ValueError, "If a string is passed to"):
1314+
kmath.stft(
1315+
x, sequence_length, sequence_stride, fft_length, window=window
1316+
)
1317+
1318+
def test_stft_invalid_window_shape(self):
1319+
# backend agnostic error message
1320+
x = np.array([1.0, 2.0, 3.0, 4.0])
1321+
sequence_length = 2
1322+
sequence_stride = 1
1323+
fft_length = 4
1324+
window = np.ones((sequence_length + 1))
1325+
with self.assertRaisesRegex(ValueError, "The shape of `window` must"):
1326+
kmath.stft(
1327+
x, sequence_length, sequence_stride, fft_length, window=window
1328+
)
1329+
1330+
def test_istft_invalid_window_shape_2D_inputs(self):
1331+
# backend agnostic error message
1332+
x = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]]))
1333+
sequence_length = 2
1334+
sequence_stride = 1
1335+
fft_length = 4
1336+
incorrect_window = np.ones((sequence_length + 1,))
1337+
with self.assertRaisesRegex(
1338+
ValueError, "The shape of `window` must be equal to"
1339+
):
1340+
kmath.istft(
1341+
x,
1342+
sequence_length,
1343+
sequence_stride,
1344+
fft_length,
1345+
window=incorrect_window,
1346+
)

0 commit comments

Comments
 (0)