|
1 | 1 | import math
|
2 | 2 |
|
| 3 | +import jax.numpy as jnp |
3 | 4 | import numpy as np
|
4 | 5 | import pytest
|
5 | 6 | import scipy.signal
|
@@ -1256,3 +1257,90 @@ def test_undefined_fft_length_and_last_dimension(self):
|
1256 | 1257 | expected_shape = real_part.shape[:-1] + (None,)
|
1257 | 1258 |
|
1258 | 1259 | self.assertEqual(output_spec.shape, expected_shape)
|
| 1260 | + |
| 1261 | + |
| 1262 | +class TestMathErrors(testing.TestCase): |
| 1263 | + |
| 1264 | + @pytest.mark.skipif( |
| 1265 | + backend.backend() != "jax", reason="Testing Jax errors only" |
| 1266 | + ) |
| 1267 | + def test_segment_sum_no_num_segments(self): |
| 1268 | + data = jnp.array([1, 2, 3, 4]) |
| 1269 | + segment_ids = jnp.array([0, 0, 1, 1]) |
| 1270 | + with self.assertRaisesRegex( |
| 1271 | + ValueError, |
| 1272 | + "Argument `num_segments` must be set when using the JAX backend.", |
| 1273 | + ): |
| 1274 | + kmath.segment_sum(data, segment_ids) |
| 1275 | + |
| 1276 | + @pytest.mark.skipif( |
| 1277 | + backend.backend() != "jax", reason="Testing Jax errors only" |
| 1278 | + ) |
| 1279 | + def test_segment_max_no_num_segments(self): |
| 1280 | + data = jnp.array([1, 2, 3, 4]) |
| 1281 | + segment_ids = jnp.array([0, 0, 1, 1]) |
| 1282 | + with self.assertRaisesRegex( |
| 1283 | + ValueError, |
| 1284 | + "Argument `num_segments` must be set when using the JAX backend.", |
| 1285 | + ): |
| 1286 | + kmath.segment_max(data, segment_ids) |
| 1287 | + |
| 1288 | + def test_stft_invalid_input_type(self): |
| 1289 | + # backend agnostic error message |
| 1290 | + x = np.array([1, 2, 3, 4]) |
| 1291 | + sequence_length = 2 |
| 1292 | + sequence_stride = 1 |
| 1293 | + fft_length = 4 |
| 1294 | + with self.assertRaisesRegex(TypeError, "`float32` or `float64`"): |
| 1295 | + kmath.stft(x, sequence_length, sequence_stride, fft_length) |
| 1296 | + |
| 1297 | + def test_invalid_fft_length(self): |
| 1298 | + # backend agnostic error message |
| 1299 | + x = np.array([1.0, 2.0, 3.0, 4.0]) |
| 1300 | + sequence_length = 4 |
| 1301 | + sequence_stride = 1 |
| 1302 | + fft_length = 2 |
| 1303 | + with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"): |
| 1304 | + kmath.stft(x, sequence_length, sequence_stride, fft_length) |
| 1305 | + |
| 1306 | + def test_stft_invalid_window(self): |
| 1307 | + # backend agnostic error message |
| 1308 | + x = np.array([1.0, 2.0, 3.0, 4.0]) |
| 1309 | + sequence_length = 2 |
| 1310 | + sequence_stride = 1 |
| 1311 | + fft_length = 4 |
| 1312 | + window = "invalid_window" |
| 1313 | + with self.assertRaisesRegex(ValueError, "If a string is passed to"): |
| 1314 | + kmath.stft( |
| 1315 | + x, sequence_length, sequence_stride, fft_length, window=window |
| 1316 | + ) |
| 1317 | + |
| 1318 | + def test_stft_invalid_window_shape(self): |
| 1319 | + # backend agnostic error message |
| 1320 | + x = np.array([1.0, 2.0, 3.0, 4.0]) |
| 1321 | + sequence_length = 2 |
| 1322 | + sequence_stride = 1 |
| 1323 | + fft_length = 4 |
| 1324 | + window = np.ones((sequence_length + 1)) |
| 1325 | + with self.assertRaisesRegex(ValueError, "The shape of `window` must"): |
| 1326 | + kmath.stft( |
| 1327 | + x, sequence_length, sequence_stride, fft_length, window=window |
| 1328 | + ) |
| 1329 | + |
| 1330 | + def test_istft_invalid_window_shape_2D_inputs(self): |
| 1331 | + # backend agnostic error message |
| 1332 | + x = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]])) |
| 1333 | + sequence_length = 2 |
| 1334 | + sequence_stride = 1 |
| 1335 | + fft_length = 4 |
| 1336 | + incorrect_window = np.ones((sequence_length + 1,)) |
| 1337 | + with self.assertRaisesRegex( |
| 1338 | + ValueError, "The shape of `window` must be equal to" |
| 1339 | + ): |
| 1340 | + kmath.istft( |
| 1341 | + x, |
| 1342 | + sequence_length, |
| 1343 | + sequence_stride, |
| 1344 | + fft_length, |
| 1345 | + window=incorrect_window, |
| 1346 | + ) |
0 commit comments