Skip to content

Commit 838c7da

Browse files
committed
Merge branch 'master' of github.com:keras-team/keras
2 parents c1dfba3 + f3a01a7 commit 838c7da

File tree

8 files changed

+216
-87
lines changed

8 files changed

+216
-87
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ jobs:
8888
env_vars: PYTHON,KERAS_HOME
8989
flags: keras,keras-${{ matrix.backend }}
9090
files: core-coverage.xml
91+
token: ${{ secrets.CODECOV_TOKEN }}
9192
fail_ci_if_error: false
9293

9394
format:

keras/trainers/data_adapters/data_adapter_utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from keras.api_export import keras_export
55
from keras.utils import tree
66

7+
NUM_BATCHES_FOR_TENSOR_SPEC = 2
8+
79

810
@keras_export("keras.utils.unpack_x_y_sample_weight")
911
def unpack_x_y_sample_weight(data):
@@ -125,6 +127,54 @@ def class_weight_to_sample_weights(y, class_weight):
125127
return sample_weight
126128

127129

130+
def get_tensor_spec(batches):
    """Return the common tensor spec for a list of batches.

    For every leaf of the batch structure, the dtype and the kind of spec
    (dense, ragged or sparse) are taken from the first batch. The shape is
    merged across all batches: a dimension is kept only when every batch
    agrees on its value and is set to `None` otherwise. The first dimension
    is always set to `None`.

    Args:
        batches: non-empty list of structures of tensors. The structures must
            be identical, but the shape at each leaf may be different.

    Returns:
        The common tensor spec for all the batches, as produced by
        `tree.map_structure`, with `tf.TensorSpec`, `tf.RaggedTensorSpec` or
        `tf.SparseTensorSpec` leaves.

    Raises:
        ValueError: if `batches` is empty, if a leaf of the first batch has
            rank 0, or if corresponding leaves do not all have the same rank.
    """
    # Fail early with a clear message: mapping over zero structures would
    # otherwise surface as an unrelated error from the tree utilities.
    if not batches:
        raise ValueError(
            "When passing a dataset to a Keras model, the dataset must "
            "yield at least one batch. Received an empty list of batches."
        )

    from keras.utils.module_utils import tensorflow as tf

    def get_single_tensor_spec(*tensors):
        # `tensors` holds the same leaf taken from each batch, in order.
        x = tensors[0]
        rank = len(x.shape)
        if rank < 1:
            raise ValueError(
                "When passing a dataset to a Keras model, the arrays must "
                f"be at least rank 1. Received: {x} of rank {rank}."
            )
        for t in tensors:
            if len(t.shape) != rank:
                raise ValueError(
                    "When passing a dataset to a Keras model, the "
                    "corresponding arrays in each batch must have the same "
                    f"rank. Received: {x} and {t}"
                )
        shape = []
        # Merge shapes: go through each dimension one by one and keep the
        # common values
        for dims in zip(*[list(t.shape) for t in tensors]):
            dims_set = set(dims)
            shape.append(dims_set.pop() if len(dims_set) == 1 else None)
        shape[0] = None  # batch size may not be static

        dtype = backend.standardize_dtype(x.dtype)
        if isinstance(x, tf.RaggedTensor):
            return tf.RaggedTensorSpec(shape=shape, dtype=dtype)
        if (
            isinstance(x, tf.SparseTensor)
            or is_scipy_sparse(x)
            or is_jax_sparse(x)
        ):
            return tf.SparseTensorSpec(shape=shape, dtype=dtype)
        else:
            return tf.TensorSpec(shape=shape, dtype=dtype)

    return tree.map_structure(get_single_tensor_spec, *batches)
176+
177+
128178
def get_jax_iterator(iterable):
129179
from keras.backend.jax.core import convert_to_tensor
130180

keras/trainers/data_adapters/generator_data_adapter.py

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import itertools
22

3-
from keras import backend
43
from keras.trainers.data_adapters import data_adapter_utils
54
from keras.trainers.data_adapters.data_adapter import DataAdapter
65
from keras.utils import tree
@@ -10,49 +9,19 @@ class GeneratorDataAdapter(DataAdapter):
109
"""Adapter for Python generators."""
1110

1211
def __init__(self, generator):
13-
first_batch, generator = peek_and_restore(generator)
12+
first_batches, generator = peek_and_restore(generator)
1413
self.generator = generator
15-
self._first_batch = first_batch
14+
self._first_batches = first_batches
1615
self._output_signature = None
17-
if not isinstance(first_batch, tuple):
16+
if not isinstance(first_batches[0], tuple):
1817
raise ValueError(
1918
"When passing a Python generator to a Keras model, "
2019
"the generator must return a tuple, either "
2120
"(input,) or (inputs, targets) or "
2221
"(inputs, targets, sample_weights). "
23-
f"Received: {first_batch}"
22+
f"Received: {first_batches[0]}"
2423
)
2524

26-
def _set_tf_output_signature(self):
27-
from keras.utils.module_utils import tensorflow as tf
28-
29-
def get_tensor_spec(x):
30-
shape = x.shape
31-
if len(shape) < 1:
32-
raise ValueError(
33-
"When passing a Python generator to a Keras model, "
34-
"the arrays returned by the generator "
35-
"must be at least rank 1. Received: "
36-
f"{x} of rank {len(x.shape)}"
37-
)
38-
shape = list(shape)
39-
shape[0] = None # The batch size is not guaranteed to be static.
40-
dtype = backend.standardize_dtype(x.dtype)
41-
if isinstance(x, tf.RaggedTensor):
42-
return tf.RaggedTensorSpec(shape=shape, dtype=dtype)
43-
if (
44-
isinstance(x, tf.SparseTensor)
45-
or data_adapter_utils.is_scipy_sparse(x)
46-
or data_adapter_utils.is_jax_sparse(x)
47-
):
48-
return tf.SparseTensorSpec(shape=shape, dtype=dtype)
49-
else:
50-
return tf.TensorSpec(shape=shape, dtype=dtype)
51-
52-
self._output_signature = tree.map_structure(
53-
get_tensor_spec, self._first_batch
54-
)
55-
5625
def get_numpy_iterator(self):
5726
return data_adapter_utils.get_numpy_iterator(self.generator)
5827

@@ -85,7 +54,9 @@ def get_tf_iterator():
8554
yield batch
8655

8756
if self._output_signature is None:
88-
self._set_tf_output_signature()
57+
self._output_signature = data_adapter_utils.get_tensor_spec(
58+
self._first_batches
59+
)
8960
ds = tf.data.Dataset.from_generator(
9061
get_tf_iterator,
9162
output_signature=self._output_signature,
@@ -106,5 +77,9 @@ def batch_size(self):
10677

10778

10879
def peek_and_restore(generator):
109-
element = next(generator)
110-
return element, itertools.chain([element], generator)
80+
batches = list(
81+
itertools.islice(
82+
generator, data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC
83+
)
84+
)
85+
return batches, itertools.chain(batches, generator)

keras/trainers/data_adapters/generator_data_adapter_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,41 @@ def test_basic_flow(self, use_sample_weight, generator_type, iterator_type):
101101
sample_order.append(by[i, 0])
102102
self.assertAllClose(sample_order, list(range(34)))
103103

104+
@parameterized.named_parameters(
    named_product(iterator_type=["np", "tf", "jax", "torch"])
)
def test_with_different_shapes(self, iterator_type):
    """Check that batches of differing shapes keep their original shapes.

    A generator yields three `(x, y)` float32 batches with different shapes.
    The adapter is iterated through the requested iterator type and every
    batch is checked for its length, dtype and shapes.
    """
    # (x shape, y shape) of each batch, in the order they are yielded.
    batch_shapes = [
        ((16, 4), (16, 2)),
        ((16, 5), (16, 2)),
        ((2, 6), (2, 2)),
    ]

    def generator():
        for x_shape, y_shape in batch_shapes:
            yield np.ones(x_shape, "float32"), np.ones(y_shape, "float32")

    adapter = generator_data_adapter.GeneratorDataAdapter(generator())

    make_iterator = {
        "np": adapter.get_numpy_iterator,
        "tf": adapter.get_tf_dataset,
        "jax": adapter.get_jax_iterator,
        "torch": adapter.get_torch_dataloader,
    }
    it = make_iterator[iterator_type]()

    for index, batch in enumerate(it):
        self.assertEqual(len(batch), 2)
        inputs, targets = batch
        self.assertEqual(inputs.dtype, targets.dtype)
        self.assertContainsExactSubsequence(str(inputs.dtype), "float32")
        # Any batch after the second is compared against the last entry.
        x_shape, y_shape = batch_shapes[min(index, 2)]
        self.assertEqual(inputs.shape, x_shape)
        self.assertEqual(targets.shape, y_shape)
138+
104139
@parameterized.named_parameters(
105140
named_product(
106141
generator_type=["tf", "jax", "scipy"], iterator_type=["tf", "jax"]

keras/trainers/data_adapters/py_dataset_adapter.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99

1010
import numpy as np
1111

12-
from keras import backend
1312
from keras.api_export import keras_export
1413
from keras.trainers.data_adapters import data_adapter_utils
1514
from keras.trainers.data_adapters.data_adapter import DataAdapter
16-
from keras.utils import tree
1715

1816

1917
@keras_export(["keras.utils.PyDataset", "keras.utils.Sequence"])
@@ -188,28 +186,6 @@ def __init__(
188186
self.shuffle = shuffle
189187
self._output_signature = None
190188

191-
def _set_tf_output_signature(self):
192-
from keras.utils.module_utils import tensorflow as tf
193-
194-
def get_tensor_spec(x):
195-
shape = x.shape
196-
if len(shape) < 1:
197-
raise ValueError(
198-
"The arrays returned by PyDataset.__getitem__() "
199-
"must be at least rank 1. Received: "
200-
f"{x} of rank {len(x.shape)}"
201-
)
202-
shape = list(shape)
203-
shape[0] = None # The batch size is not guaranteed to be static.
204-
dtype = backend.standardize_dtype(x.dtype)
205-
return tf.TensorSpec(shape=shape, dtype=dtype)
206-
207-
# Grab the first example
208-
batch = self.py_dataset[0]
209-
# Run checks on it and format it
210-
batch = self._standardize_batch(batch)
211-
self._output_signature = tree.map_structure(get_tensor_spec, batch)
212-
213189
def _standardize_batch(self, batch):
214190
if isinstance(batch, dict):
215191
return batch
@@ -287,7 +263,15 @@ def get_tf_dataset(self):
287263
from keras.utils.module_utils import tensorflow as tf
288264

289265
if self._output_signature is None:
290-
self._set_tf_output_signature()
266+
num_samples = min(
267+
data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC,
268+
len(self.py_dataset),
269+
)
270+
batches = [
271+
self._standardize_batch(self.py_dataset[i])
272+
for i in range(num_samples)
273+
]
274+
self._output_signature = data_adapter_utils.get_tensor_spec(batches)
291275

292276
ds = tf.data.Dataset.from_generator(
293277
self._get_iterator,

keras/trainers/data_adapters/py_dataset_adapter_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,54 @@ def test_dict_inputs(self):
233233
self.assertEqual(bx.dtype, by.dtype)
234234
self.assertEqual(tuple(bx.shape), (4, 4))
235235
self.assertEqual(tuple(by.shape), (4, 2))
236+
237+
@parameterized.named_parameters(
    named_product(iterator_type=["np", "tf", "jax", "torch"])
)
def test_with_different_shapes(self, iterator_type):
    """Check that batches of differing shapes keep their original shapes.

    A three-item `PyDataset` returns `(x, y)` float32 batches with different
    shapes. The adapter is built with `shuffle=False`, iterated through the
    requested iterator type, and every batch is checked for its length,
    dtype and shapes.
    """
    # (x shape, y shape) of the batch returned for index 0, 1 and "other".
    batch_shapes = [
        ((16, 4), (16, 2)),
        ((16, 5), (16, 2)),
        ((2, 6), (2, 2)),
    ]

    class TestPyDataset(py_dataset_adapter.PyDataset):
        def __len__(self):
            return 3

        def __getitem__(self, idx):
            # Every index other than 0 and 1 maps to the last entry.
            position = 0 if idx == 0 else 1 if idx == 1 else 2
            x_shape, y_shape = batch_shapes[position]
            return np.ones(x_shape, "float32"), np.ones(y_shape, "float32")

    adapter = py_dataset_adapter.PyDatasetAdapter(
        TestPyDataset(), shuffle=False
    )

    make_iterator = {
        "np": adapter.get_numpy_iterator,
        "tf": adapter.get_tf_dataset,
        "jax": adapter.get_jax_iterator,
        "torch": adapter.get_torch_dataloader,
    }
    it = make_iterator[iterator_type]()

    for index, batch in enumerate(it):
        self.assertEqual(len(batch), 2)
        inputs, targets = batch
        self.assertEqual(inputs.dtype, targets.dtype)
        self.assertContainsExactSubsequence(str(inputs.dtype), "float32")
        # Any batch after the second is compared against the last entry.
        x_shape, y_shape = batch_shapes[min(index, 2)]
        self.assertEqual(inputs.shape, x_shape)
        self.assertEqual(targets.shape, y_shape)

keras/trainers/data_adapters/torch_data_loader_adapter.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import itertools
2+
13
import numpy as np
24

3-
from keras import backend
45
from keras.trainers.data_adapters import data_adapter_utils
56
from keras.trainers.data_adapters.data_adapter import DataAdapter
67
from keras.utils import tree
@@ -19,6 +20,7 @@ def __init__(self, dataloader):
1920
)
2021

2122
self._dataloader = dataloader
23+
self._output_signature = None
2224
self._batch_size = dataloader.batch_size
2325
self._num_batches = None
2426
self._partial_batch_size = None
@@ -44,36 +46,24 @@ def get_jax_iterator(self):
4446
def get_tf_dataset(self):
4547
from keras.utils.module_utils import tensorflow as tf
4648

47-
output_signature = self.peek_and_get_tensor_spec()
49+
if self._output_signature is None:
50+
batches = list(
51+
itertools.islice(
52+
self._dataloader,
53+
data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC,
54+
)
55+
)
56+
self._output_signature = tuple(
57+
data_adapter_utils.get_tensor_spec(batches)
58+
)
4859
return tf.data.Dataset.from_generator(
4960
self.get_numpy_iterator,
50-
output_signature=output_signature,
61+
output_signature=self._output_signature,
5162
)
5263

5364
def get_torch_dataloader(self):
5465
return self._dataloader
5566

56-
def peek_and_get_tensor_spec(self):
57-
from keras.utils.module_utils import tensorflow as tf
58-
59-
batch_data = next(iter(self._dataloader))
60-
61-
def get_tensor_spec(x):
62-
shape = x.shape
63-
if len(shape) < 1:
64-
raise ValueError(
65-
"When passing a Pytorch DataLoader to a Keras model, "
66-
"the arrays returned by the generator "
67-
"must be at least rank 1. Received: "
68-
f"{x} of rank {len(x.shape)}"
69-
)
70-
shape = list(shape)
71-
shape[0] = None # The batch size is not guaranteed to be static.
72-
dtype = backend.standardize_dtype(x.dtype)
73-
return tf.TensorSpec(shape=shape, dtype=dtype)
74-
75-
return tuple(tree.map_structure(get_tensor_spec, batch_data))
76-
7767
@property
7868
def num_batches(self):
7969
return self._num_batches

0 commit comments

Comments
 (0)