14
14
import yahp as hp
15
15
from transformers .testing_utils import CaptureLogger
16
16
17
- from composer .core .types import Batch
17
+ from composer .core .types import Batch , DataSpec
18
18
from composer .datasets .dataloader import DataloaderHparams
19
- from composer .datasets .hparams import DataloaderSpec , DatasetHparams
19
+ from composer .datasets .hparams import DatasetHparams
20
20
from composer .utils import dist
21
21
from composer .utils .data import get_subset_dataset
22
22
@@ -73,18 +73,18 @@ def _load_dataset(self):
73
73
split = self .split ,
74
74
streaming = True )
75
75
76
def _get_approx_num_samples_per_device(self):
    """Estimate how many samples a single device will consume.

    When ``self.max_samples`` is positive it is used as the global sample
    budget. Otherwise the size is looked up in ``CACHED_DATASET_SIZES`` by
    dataset name, config name and split, with ``self.max_shards`` (when
    positive) overriding the cached shard count. In both cases the global
    count is floor-divided by ``dist.get_world_size()``.

    Returns:
        The approximate per-device sample count.

    Raises:
        NotImplementedError: If the estimate cannot be computed, e.g. the
            dataset is missing from ``CACHED_DATASET_SIZES``. The underlying
            error is attached as ``__cause__``.
    """
    try:
        if self.max_samples > 0:
            return self.max_samples // dist.get_world_size()

        n_shards, samples_per_shard = CACHED_DATASET_SIZES[self.dataset_name][self.dataset_config_name][self.split]
        if self.max_shards > 0:
            # An explicit shard cap takes precedence over the cached count.
            n_shards = self.max_shards
        return (n_shards * samples_per_shard) // dist.get_world_size()
    except Exception as err:
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not rewritten, and chain the
        # cause so the real failure stays visible in the traceback.
        raise NotImplementedError(
            "Unable to estimate the number of samples per device for this dataset configuration.") from err
86
86
87
def _get_approx_num_tokens_per_device(self):
    """Return a placeholder for the per-device token count.

    The value is a fixed constant and does not depend on the dataset
    configuration.
    NOTE(review): presumably 1e12 stands in for "effectively unbounded"
    because the real token count of a streamed dataset is unknown -- confirm.

    Returns:
        The float ``1e12``.
    """
    return 1e12
89
89
90
90
def _subsample (self , device_offset , text_batch ):
@@ -166,7 +166,7 @@ def _group_tokens(self, token_batch):
166
166
else :
167
167
raise ValueError (f"Unknown group_method: '{ group_method } '" )
168
168
169
- def initialize_object (self , batch_size : int , dataloader_hparams : DataloaderHparams ) -> DataloaderSpec :
169
+ def initialize_object (self , batch_size : int , dataloader_hparams : DataloaderHparams ) -> DataSpec :
170
170
assert dataloader_hparams .num_workers == 1 , "LM Streaming Dataloader only supports num_workers=1"
171
171
172
172
try :
@@ -209,13 +209,12 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
209
209
batch_size = token_sample_batch_size ,
210
210
)
211
211
212
- # Maybe limit the number of post-processed samples
213
- if self .max_samples > 0 :
214
- token_dataset = token_dataset .take (self .max_samples // dist .get_world_size ())
215
-
216
- # Add approx num samples and create a SizedIterableDataset
217
- sized_iterable_dataset = SizedIterableDataset (token_dataset , self ._get_approx_num_samples ())
212
+ # Limit the number of post-processed samples
213
+ num_samples_per_device = self ._get_approx_num_samples_per_device ()
214
+ token_dataset = token_dataset .take (num_samples_per_device )
218
215
216
+ # HACK: create a SizedIterableDataset
217
+ sized_iterable_dataset = SizedIterableDataset (token_dataset , num_samples_per_device )
219
218
220
219
# Get collate_fn
221
220
if self .tokenizer_name in ["gpt2" ]:
@@ -225,25 +224,25 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
225
224
collate_fn = transformers .DataCollatorForLanguageModeling (tokenizer = self .tokenizer ,
226
225
mlm = self .use_masked_lm ,
227
226
mlm_probability = self .mlm_probability )
228
- # Return DataloaderSpec
229
- return DataloaderSpec (dataloader = dataloader_hparams .initialize_object (
227
+ # Return DataSpec
228
+ return DataSpec (dataloader = dataloader_hparams .initialize_object (
230
229
dataset = sized_iterable_dataset ,
231
230
batch_size = batch_size ,
232
231
sampler = None ,
233
232
drop_last = self .drop_last ,
234
233
collate_fn = collate_fn ,
235
234
),
236
- split_fn = _split_dict_fn )
235
+ split_batch = _split_dict_fn )
237
236
238
237
239
238
class SizedIterableDataset(torch.utils.data.IterableDataset):
    """Iterable dataset wrapper that reports a caller-supplied length.

    Iteration is delegated unchanged to the wrapped iterable; ``len()``
    returns the number given at construction time rather than anything
    measured from the data.

    Args:
        hf_iterable_dataset: The iterable to wrap. Only ``iter()`` is ever
            called on it.
        num_samples_per_device: The value ``len()`` should report.
    """

    def __init__(self, hf_iterable_dataset, num_samples_per_device):
        self.hf_iterable_dataset = hf_iterable_dataset
        self.num_samples_per_device = num_samples_per_device

    def __iter__(self):
        """Return a fresh iterator over the wrapped iterable."""
        return iter(self.hf_iterable_dataset)

    def __len__(self):
        """Return the length supplied at construction time."""
        return self.num_samples_per_device
0 commit comments