Commit 6d8498c

Fixed issue #135; rename total_batch_size to train_batch_size (#137)
1. Remove `subset_num_batches` from the dataset hparams. Synthetic datasets should instead use the length of the real dataset as their size, or have a configurable size.
2. Add `train_subset_num_batches` and `eval_subset_num_batches` to the trainer hparams.
3. Add a check in the trainer that ensures that, if this field is set, `DatasetHparams.shuffle` is False; otherwise, emit a warning that every epoch may use a different subset of samples.
4. Rename `total_batch_size` to `train_batch_size` and update the hparams files accordingly (a config sketch follows below).
1 parent 383b62a commit 6d8498c
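
For orientation, a hedged Python sketch of a trainer config fragment after this change; the field names come from the commit message and the YAML diff below, while the values and the dict form itself are illustrative:

# Illustrative trainer config fragment; values are placeholders.
trainer_config = {
    'max_epochs': 200,
    'train_batch_size': 64,            # renamed from 'total_batch_size'
    'eval_batch_size': 8,
    # New trainer-level fields that replace the per-dataset subset_num_batches:
    'train_subset_num_batches': 100,   # cap the number of training batches per epoch
    'eval_subset_num_batches': 10,     # cap the number of evaluation batches
}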

45 files changed (+183 -233 lines)

composer/core/state.py

Lines changed: 9 additions & 3 deletions
@@ -56,6 +56,7 @@
     "eval_dataloader",
     "precision",
     "precision_context",
+    "_steps_per_epoch",
 ]


@@ -116,6 +117,7 @@ class State(Serializable):
     # but the getter will always return a Precision enum
     precision: Union[str, types.Precision]  # type: ignore
     _precision: types.Precision = field(init=False)  # but store an enum internally
+    _steps_per_epoch: Optional[int] = field(init=False, default=None)
     precision_context: Callable[[Union[str, Precision]], ContextManager] = \
         field(default_factory=default_precision_factory)


@@ -210,9 +212,13 @@ def batch_idx(self) -> int:
     @property
     def steps_per_epoch(self) -> int:
         """int: The number of steps (batches) per epoch."""
-        if self.train_dataloader is None:
-            raise RuntimeError("To determine the number of steps per epoch, state.train_dataloader must be set.")
-        return len(self.train_dataloader)
+        if self._steps_per_epoch is None:
+            return len(self.train_dataloader)
+        return self._steps_per_epoch
+
+    @steps_per_epoch.setter
+    def steps_per_epoch(self, val: Optional[int]):  # type: ignore
+        self._steps_per_epoch = val

     @property
     def precision(self) -> types.Precision:
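
A minimal, self-contained sketch of the override-with-fallback pattern that the new steps_per_epoch property and setter implement (the class below is illustrative, not the real State):

from typing import List, Optional

class _StateSketch:
    # Stand-in for the steps_per_epoch behavior added to composer.core.state.State.
    def __init__(self, train_dataloader: List[int]):
        self.train_dataloader = train_dataloader
        self._steps_per_epoch: Optional[int] = None

    @property
    def steps_per_epoch(self) -> int:
        # Fall back to the dataloader length when no override has been set.
        if self._steps_per_epoch is None:
            return len(self.train_dataloader)
        return self._steps_per_epoch

    @steps_per_epoch.setter
    def steps_per_epoch(self, val: Optional[int]):
        # The trainer can set this (e.g. from train_subset_num_batches) to shorten an epoch.
        self._steps_per_epoch = val

state = _StateSketch(train_dataloader=list(range(500)))
assert state.steps_per_epoch == 500   # falls back to len(train_dataloader)
state.steps_per_epoch = 100
assert state.steps_per_epoch == 100   # the override takes precedence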

composer/datasets/brats.py

Lines changed: 0 additions & 4 deletions
@@ -15,7 +15,6 @@
 from composer.datasets.dataloader import DataloaderHparams
 from composer.datasets.hparams import DatasetHparams
 from composer.utils import ddp
-from composer.utils.data import get_subset_dataset

 PATCH_SIZE = [1, 192, 160]


@@ -48,9 +47,6 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
             raise ValueError("datadir must be specified if self.synthetic is False")
         x_train, y_train, x_val, y_val = get_data_split(self.datadir)
         dataset = PytTrain(x_train, y_train, oversampling) if self.is_train else PytVal(x_val, y_val)
-        if self.subset_num_batches is not None:
-            size = batch_size * self.subset_num_batches * ddp.get_world_size()
-            dataset = get_subset_dataset(size, dataset)
         collate_fn = None if self.is_train else _my_collate
         sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

composer/datasets/cifar10.py

Lines changed: 3 additions & 13 deletions
@@ -2,7 +2,6 @@

 from dataclasses import dataclass

-import torch.utils.data
 import yahp as hp
 from torchvision import transforms
 from torchvision.datasets import CIFAR10

@@ -12,7 +11,6 @@
 from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin
 from composer.datasets.synthetic import SyntheticBatchPairDataset
 from composer.utils import ddp
-from composer.utils.data import get_subset_dataset


 @dataclass

@@ -28,20 +26,15 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
         cifar10_mean, cifar10_std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]

         if self.use_synthetic:
-            if self.subset_num_batches is None:
-                raise ValueError("subset_num_batches is required if use_synthetic is True")
+            total_dataset_size = 50_000 if self.is_train else 10_000
             dataset = SyntheticBatchPairDataset(
-                total_dataset_size=self.subset_num_batches * batch_size,
+                total_dataset_size=total_dataset_size,
                 data_shape=[3, 32, 32],
                 num_classes=10,
                 num_unique_samples_to_create=self.synthetic_num_unique_samples,
                 device=self.synthetic_device,
                 memory_format=self.synthetic_memory_format,
             )
-            if self.shuffle:
-                sampler = torch.utils.data.RandomSampler(dataset)
-            else:
-                sampler = torch.utils.data.SequentialSampler(dataset)

         else:
             if self.datadir is None:

@@ -66,10 +59,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
                 download=self.download,
                 transform=transformation,
             )
-            if self.subset_num_batches is not None:
-                size = batch_size * self.subset_num_batches * ddp.get_world_size()
-                dataset = get_subset_dataset(size, dataset)
-            sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
+            sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

         return dataloader_hparams.initialize_object(dataset,
                                                     batch_size=batch_size,
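
A hedged, dependency-free sketch of the sizing change in these dataset hparams; the split sizes follow the diff above, and the helper names are illustrative:

def synthetic_cifar10_size(is_train: bool) -> int:
    # After this commit, the synthetic CIFAR-10 dataset mirrors the real split size
    # instead of being sized as subset_num_batches * batch_size.
    return 50_000 if is_train else 10_000

def removed_subset_size(batch_size: int, subset_num_batches: int, world_size: int) -> int:
    # The per-dataset subset sizing that was deleted; capping now happens in the trainer
    # via train_subset_num_batches / eval_subset_num_batches.
    return batch_size * subset_num_batches * world_size

assert synthetic_cifar10_size(is_train=True) == 50_000
assert removed_subset_size(batch_size=128, subset_num_batches=10, world_size=8) == 10_240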

composer/datasets/hparams.py

Lines changed: 0 additions & 11 deletions
@@ -90,14 +90,6 @@ class DatasetHparams(hp.Hparams, abc.ABC, metaclass=metaclass):
             If the number of samples is not divisible by the batch size, whether
             to drop the last batch (the default) or pad the last batch with zeros.
         shuffle (bool): Whether to shuffle the dataset. Defaults to True.
-        subset_num_batches (int, optional): If specified, limit the number of batches per dataloader iteration.
-            Specifically, ``len(dataloader) == num_total_batches``, where the ``dataloader`` is returned via
-            :meth:`initialize_object`. Each epoch should yield the same subset of samples.
-
-            If this value is greater than the total number of samples in the dataset, then a :class:`ValueError`
-            is raised.
-
-            If None (the default), then the entire dataset will be iterated over.
     """

     is_train: bool = hp.optional("Whether to load the training data (the default) or validation data.", default=True)

@@ -106,9 +98,6 @@ class DatasetHparams(hp.Hparams, abc.ABC, metaclass=metaclass):
                              default=True)
     shuffle: bool = hp.optional("Whether to shuffle the dataset for each epoch. Defaults to True.", default=True)

-    subset_num_batches: Optional[int] = hp.optional(
-        "If not None, limit len(dataloader) to this many batches. If None (the default), then the dataloader will iterate over the entire dataset.",
-        default=None)
     datadir: Optional[str] = hp.optional("The path to the data directory", default=None)

     @abc.abstractmethod
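
The commit message also describes a trainer-side check that pairs with this removal. A minimal sketch of that check, using hypothetical argument names (the warning text is illustrative, not the library's):

import warnings
from typing import Optional

def check_subset_shuffle(train_subset_num_batches: Optional[int], dataset_shuffle: bool) -> None:
    # If only a subset of batches is used per epoch while the dataset is shuffled,
    # each epoch may iterate over a different subset of samples.
    if train_subset_num_batches is not None and dataset_shuffle:
        warnings.warn("train_subset_num_batches is set while DatasetHparams.shuffle is True; "
                      "every epoch may use a different subset of samples.")

check_subset_shuffle(train_subset_num_batches=100, dataset_shuffle=True)  # emits the warning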

composer/datasets/imagenet.py

Lines changed: 3 additions & 12 deletions
@@ -17,7 +17,6 @@
 from composer.datasets.hparams import DataloaderSpec, DatasetHparams, SyntheticHparamsMixin
 from composer.datasets.synthetic import SyntheticBatchPairDataset
 from composer.utils import ddp
-from composer.utils.data import get_subset_dataset


 class TransformationFn:

@@ -80,10 +79,9 @@ class ImagenetDatasetHparams(DatasetHparams, SyntheticHparamsMixin):
     def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataloaderSpec:

         if self.use_synthetic:
-            if self.subset_num_batches is None:
-                raise ValueError("subset_num_batches is required if use_synthetic is True")
+            total_dataset_size = 1_281_167 if self.is_train else 50_000
             dataset = SyntheticBatchPairDataset(
-                total_dataset_size=self.subset_num_batches * batch_size * ddp.get_world_size(),
+                total_dataset_size=total_dataset_size,
                 data_shape=[3, self.crop_size, self.crop_size],
                 num_classes=1000,
                 num_unique_samples_to_create=self.synthetic_num_unique_samples,

@@ -92,10 +90,6 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
             )
             collate_fn = None
             device_transform_fn = None
-            if self.shuffle:
-                sampler = torch.utils.data.RandomSampler(dataset)
-            else:
-                sampler = torch.utils.data.SequentialSampler(dataset)
         else:

             if self.is_train:

@@ -125,10 +119,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
             if self.datadir is None:
                 raise ValueError("datadir must be specified is self.synthetic is False")
             dataset = ImageFolder(os.path.join(self.datadir, split), transformation)
-            if self.subset_num_batches is not None:
-                size = batch_size * self.subset_num_batches * ddp.get_world_size()
-                dataset = get_subset_dataset(size, dataset)
-            sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
+            sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

         return DataloaderSpec(dataloader=dataloader_hparams.initialize_object(
             dataset=dataset,

composer/datasets/lm_datasets.py

Lines changed: 0 additions & 5 deletions
@@ -12,7 +12,6 @@
 from composer.datasets.dataloader import DataloaderHparams
 from composer.datasets.hparams import DataloaderSpec, DatasetHparams
 from composer.utils import ddp
-from composer.utils.data import get_subset_dataset

 log = logging.getLogger(__name__)


@@ -98,10 +97,6 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
         log.info(f"Total number of samples: {num_samples:e}")
         log.info(f"Total number of tokens: {self.num_tokens:e}")
         dataset = lm_datasets
-        if self.subset_num_batches is not None:
-            size = batch_size * self.subset_num_batches * ddp.get_world_size()
-            dataset = get_subset_dataset(size, dataset)
-
         data_collator = transformers.default_data_collator

         sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

composer/datasets/mnist.py

Lines changed: 2 additions & 13 deletions
@@ -2,7 +2,6 @@

 from dataclasses import dataclass

-import torch.utils.data
 import yahp as hp
 from torchvision import datasets, transforms


@@ -11,7 +10,6 @@
 from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin
 from composer.datasets.synthetic import SyntheticBatchPairDataset
 from composer.utils import ddp
-from composer.utils.data import get_subset_dataset


 @dataclass

@@ -25,20 +23,14 @@ class MNISTDatasetHparams(DatasetHparams, SyntheticHparamsMixin):

     def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataLoader:
         if self.use_synthetic:
-            if self.subset_num_batches is None:
-                raise ValueError("subset_num_batches is required if use_synthetic is True")
             dataset = SyntheticBatchPairDataset(
-                total_dataset_size=self.subset_num_batches * batch_size,
+                total_dataset_size=60_000 if self.is_train else 10_000,
                 data_shape=[1, 28, 28],
                 num_classes=10,
                 num_unique_samples_to_create=self.synthetic_num_unique_samples,
                 device=self.synthetic_device,
                 memory_format=self.synthetic_memory_format,
             )
-            if self.shuffle:
-                sampler = torch.utils.data.RandomSampler(dataset)
-            else:
-                sampler = torch.utils.data.SequentialSampler(dataset)

         else:
             if self.datadir is None:

@@ -51,10 +43,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
                 download=self.download,
                 transform=transform,
             )
-            if self.subset_num_batches is not None:
-                size = batch_size * self.subset_num_batches * ddp.get_world_size()
-                dataset = get_subset_dataset(size, dataset)
-            sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
+            sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
         return dataloader_hparams.initialize_object(dataset=dataset,
                                                     batch_size=batch_size,
                                                     sampler=sampler,

composer/loggers/tqdm_logger.py

Lines changed: 3 additions & 1 deletion
@@ -117,7 +117,9 @@ def init(self, state: State, logger: Logger) -> None:

     def _start(self, state: State):
         assert self.is_train is not None, "self.is_train should be set by the callback"
-        total_steps = len(state.train_dataloader) if self.is_train else len(state.eval_dataloader)
+        # TODO(anis) -- in #120, len(state.eval_dataloader) is inaccurate, as it does not incorporate
+        # trainer._eval_subset_num_batches. The evaluator spec should fix this.
+        total_steps = state.steps_per_epoch if self.is_train else len(state.eval_dataloader)
         self.pbars[self.is_train] = _TQDMLoggerInstance(total=total_steps, epoch=state.epoch, is_train=self.is_train)

     def epoch_start(self, state: State, logger: Logger) -> None:

composer/models/gpt2/scaling_laws_generator.py

Lines changed: 2 additions & 2 deletions
@@ -173,7 +173,7 @@ def parse_args():
             },
         ],
         'max_epochs': 1,
-        'total_batch_size': 8,
+        'train_batch_size': 8,
         'eval_batch_size': 8,
         'seed': 17,
         'accelerator': {

@@ -320,7 +320,7 @@ def configure_mosaic_yaml(model, scaling_law_predictions):
     logger.info(f"Minumum possible serial optimization steps before SSR: {min_serial_steps:,}")
     logger.info(f"Minumum possible serial optimization steps after SSR: {math.ceil(args.ssr * min_serial_steps):,}")
     logger.info(f"Current serial optimization steps: {final_serial_steps:,}")
-    template_yaml['total_batch_size'] = batch_size
+    template_yaml['train_batch_size'] = batch_size
     assert math.floor(batch_size / curr_grad_accum) == (batch_size / curr_grad_accum)
     template_yaml['eval_batch_size'] = math.floor(batch_size / curr_grad_accum)
     template_yaml['grad_accum'] = curr_grad_accum
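
The configure_mosaic_yaml snippet above asserts that the train batch size divides evenly by the gradient-accumulation factor before deriving the eval batch size. A worked example of that arithmetic, with illustrative values:

import math

batch_size = 512
curr_grad_accum = 8

# Holds only when batch_size is an exact multiple of curr_grad_accum.
assert math.floor(batch_size / curr_grad_accum) == (batch_size / curr_grad_accum)
eval_batch_size = math.floor(batch_size / curr_grad_accum)  # 512 / 8 == 64
print(eval_batch_size)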

composer/models/unet/hparams.yaml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ model:
 loggers:
   - tqdm: {}
 max_epochs: 200
-total_batch_size: 64
+train_batch_size: 64
 eval_batch_size: 8
 seed: 0
 validate_every_n_epochs: 1
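
Existing YAML configs need the same one-line rename shown above. A hedged migration sketch using a hypothetical helper (not part of this commit; requires PyYAML):

import yaml

def rename_batch_size_key(yaml_text: str) -> str:
    # Hypothetical helper: rename the old trainer key to the new one in a config.
    config = yaml.safe_load(yaml_text)
    if 'total_batch_size' in config:
        config['train_batch_size'] = config.pop('total_batch_size')
    return yaml.safe_dump(config, sort_keys=False)

old_config = "max_epochs: 200\ntotal_batch_size: 64\neval_batch_size: 8\n"
print(rename_batch_size_key(old_config))  # the output now contains train_batch_size: 64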
