1
1
import itertools
2
2
3
- from keras import backend
4
3
from keras .trainers .data_adapters import data_adapter_utils
5
4
from keras .trainers .data_adapters .data_adapter import DataAdapter
6
5
from keras .utils import tree
@@ -10,49 +9,19 @@ class GeneratorDataAdapter(DataAdapter):
10
9
"""Adapter for Python generators."""
11
10
12
11
def __init__ (self , generator ):
13
- first_batch , generator = peek_and_restore (generator )
12
+ first_batches , generator = peek_and_restore (generator )
14
13
self .generator = generator
15
- self ._first_batch = first_batch
14
+ self ._first_batches = first_batches
16
15
self ._output_signature = None
17
- if not isinstance (first_batch , tuple ):
16
+ if not isinstance (first_batches [ 0 ] , tuple ):
18
17
raise ValueError (
19
18
"When passing a Python generator to a Keras model, "
20
19
"the generator must return a tuple, either "
21
20
"(input,) or (inputs, targets) or "
22
21
"(inputs, targets, sample_weights). "
23
- f"Received: { first_batch } "
22
+ f"Received: { first_batches [ 0 ] } "
24
23
)
25
24
26
- def _set_tf_output_signature (self ):
27
- from keras .utils .module_utils import tensorflow as tf
28
-
29
- def get_tensor_spec (x ):
30
- shape = x .shape
31
- if len (shape ) < 1 :
32
- raise ValueError (
33
- "When passing a Python generator to a Keras model, "
34
- "the arrays returned by the generator "
35
- "must be at least rank 1. Received: "
36
- f"{ x } of rank { len (x .shape )} "
37
- )
38
- shape = list (shape )
39
- shape [0 ] = None # The batch size is not guaranteed to be static.
40
- dtype = backend .standardize_dtype (x .dtype )
41
- if isinstance (x , tf .RaggedTensor ):
42
- return tf .RaggedTensorSpec (shape = shape , dtype = dtype )
43
- if (
44
- isinstance (x , tf .SparseTensor )
45
- or data_adapter_utils .is_scipy_sparse (x )
46
- or data_adapter_utils .is_jax_sparse (x )
47
- ):
48
- return tf .SparseTensorSpec (shape = shape , dtype = dtype )
49
- else :
50
- return tf .TensorSpec (shape = shape , dtype = dtype )
51
-
52
- self ._output_signature = tree .map_structure (
53
- get_tensor_spec , self ._first_batch
54
- )
55
-
56
25
def get_numpy_iterator (self ):
57
26
return data_adapter_utils .get_numpy_iterator (self .generator )
58
27
@@ -85,7 +54,9 @@ def get_tf_iterator():
85
54
yield batch
86
55
87
56
if self ._output_signature is None :
88
- self ._set_tf_output_signature ()
57
+ self ._output_signature = data_adapter_utils .get_tensor_spec (
58
+ self ._first_batches
59
+ )
89
60
ds = tf .data .Dataset .from_generator (
90
61
get_tf_iterator ,
91
62
output_signature = self ._output_signature ,
@@ -106,5 +77,9 @@ def batch_size(self):
106
77
107
78
108
79
def peek_and_restore (generator ):
109
- element = next (generator )
110
- return element , itertools .chain ([element ], generator )
80
+ batches = list (
81
+ itertools .islice (
82
+ generator , data_adapter_utils .NUM_BATCHES_FOR_TENSOR_SPEC
83
+ )
84
+ )
85
+ return batches , itertools .chain (batches , generator )
0 commit comments