Status: Closed
Labels: stat:awaiting keras-eng (Awaiting response from Keras engineer), type:feature (the user is asking for a new feature)
Description
Toy datasets are often generated on the fly, making their effective size infinite. `Model.fit()` nicely supports this idea by accepting `batch_size` and `steps_per_epoch` parameters. However, a fully generated `PyDataset` already needs both of these values, since it must return a positive length from `__len__` and return whole batches from `__getitem__`. This leads to code like this:
```python
dataset = MyPyDataset(batch_size=128, steps_per_epoch=100)
model.fit(dataset, batch_size=128, steps_per_epoch=100)
```
which is redundant: either the parameters to `fit()` are ignored, or both places must be given the same values.
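For context, here is a minimal sketch of what such a fully generated `PyDataset` currently has to look like (the class and its data are hypothetical; the point is that both `batch_size` and `steps_per_epoch` must be baked into the dataset itself):

```python
import numpy as np
import keras

class MyPyDataset(keras.utils.PyDataset):
    def __init__(self, batch_size, steps_per_epoch, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch

    def __len__(self):
        # Must be a positive, finite number of batches, even though
        # the data could be generated forever.
        return self.steps_per_epoch

    def __getitem__(self, idx):
        # Must return a whole batch, so batch_size is needed here too.
        x = np.random.uniform(size=(self.batch_size, 8))
        y = np.random.uniform(size=(self.batch_size, 1))
        return x, y
```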
Conceptually, it makes the most sense to pass this information to `model.fit()`, so I propose two changes to `PyDataset`:
- Allow signalling a dataset of infinite length, e.g. by returning -1 from `__len__` or by raising a `TypeError`, the way `len()` on an infinite `tf.data.Dataset` does.
- Allow fetching dynamically sized batches in `__getitem__`, e.g. by returning single items instead of whole batches, like a `torch.utils.data.Dataset` (see the sketch after this list).
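To make the first change concrete, here is a purely illustrative sketch; neither the `-1` convention nor the `InfinitePyDataset` class exists in Keras today, this is just what the proposal could look like:

```python
import numpy as np
import keras

class InfinitePyDataset(keras.utils.PyDataset):
    def __init__(self, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def __len__(self):
        # Proposed convention: -1 signals "infinite" instead of
        # forcing a fixed step count into the dataset.
        return -1

    def __getitem__(self, idx):
        x = np.random.uniform(size=(self.batch_size, 8))
        y = np.random.uniform(size=(self.batch_size, 1))
        return x, y

# The epoch length would then live where it conceptually belongs:
# model.fit(InfinitePyDataset(batch_size=128), steps_per_epoch=100)
```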
I would be willing to help out with a pull request for this, but would appreciate some initial pointers and an indication of whether these changes could be merged. Thank you.