
Support Infinite PyDataset #19528

@LarsKue

Description

Toy datasets are often generated on the fly, making their effective size infinite. Model.fit() nicely supports this idea by letting you pass batch_size and steps_per_epoch parameters. However, a fully generated PyDataset already needs both of these, since it must return a positive length from __len__ and return whole batches from __getitem__. This leads to code like this:

dataset = MyPyDataset(batch_size=128, steps_per_epoch=100)
model.fit(dataset, batch_size=128, steps_per_epoch=100)

which is redundant: the parameters passed to fit() are either ignored or must duplicate the values already baked into the dataset.
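
For context, here is a minimal sketch of what such a dataset looks like today; MyPyDataset and its toy regression data are illustrative, not part of Keras. Both batching parameters must be baked into the class so that __len__ and __getitem__ can satisfy the PyDataset contract:

import numpy as np
import keras

class MyPyDataset(keras.utils.PyDataset):
    """Toy data generated on the fly; batching is fixed at construction."""

    def __init__(self, batch_size, steps_per_epoch, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch

    def __len__(self):
        # PyDataset requires a finite, positive number of batches.
        return self.steps_per_epoch

    def __getitem__(self, index):
        # Must return a complete batch, so batch_size is fixed here too.
        x = np.random.normal(size=(self.batch_size, 4))
        y = np.sum(x, axis=-1, keepdims=True)
        return x, y

dataset = MyPyDataset(batch_size=128, steps_per_epoch=100)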

Conceptually, it makes the most sense to pass this information to model.fit(), so I propose two changes to PyDataset:

  1. Allow signalling a dataset of infinite length, e.g. by returning -1 from __len__ or raising a TypeError, as len() does on an infinite tf.data.Dataset
  2. Allow __getitem__ to return dynamically sized batches, e.g. by returning single items instead of whole batches, as torch.utils.data.Dataset does (both changes are sketched below)
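
To make the proposal concrete, here is a hypothetical sketch of a dataset using both changes. None of this is current Keras behavior; the -1 sentinel and the single-item __getitem__ are exactly the semantics being proposed:

import numpy as np
import keras

class InfinitePyDataset(keras.utils.PyDataset):
    """Hypothetical dataset under the proposed semantics."""

    def __len__(self):
        # Proposed change 1: signal infinite length, so fit()'s
        # steps_per_epoch determines the epoch length instead.
        return -1

    def __getitem__(self, index):
        # Proposed change 2: return a single sample; Keras would
        # collate samples into batches of the size given to fit().
        x = np.random.normal(size=(4,))
        return x, np.sum(x, keepdims=True)

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
# Batching would then be configured once, at fit() time:
model.fit(InfinitePyDataset(), batch_size=128, steps_per_epoch=100)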

I would be willing to help with a pull request for this, but would appreciate some initial pointers and an indication of whether these changes would be considered for merging. Thank you.

Metadata

Labels

stat:awaiting keras-eng (Awaiting response from Keras engineer), type:feature (The user is asking for a new feature)
