@@ -256,26 +256,30 @@ def _subsample(self, device_offset, text_batch):
         return text_batch
 
     def _shard_dataset(self, dataset):
-        # Select a deterministic subset of filepaths for sharded data-parallel training
+        # Verify # of shards
         filepaths = dataset._ex_iterable.kwargs['filepaths']
         if self.num_shards != len(filepaths):
             raise ValueError(f"Found {len(filepaths)} shards, expected {self.num_shards}")
 
+        # Determine how to allocate devices to shards
         devices_per_shard = 1
-        if self.world_size > self.num_shards:
+        if self.num_shards < self.world_size:
             log.warning(
                 f"Not enough unique shards ({self.num_shards}) for world size ({self.world_size}). Splitting shards among devices."
             )
             if self.world_size % self.num_shards != 0:
                 raise ValueError(f"Cannot evenly split {self.num_shards} shards among {self.world_size} devices")
             devices_per_shard = self.world_size // self.num_shards
+        elif self.num_shards % self.world_size != 0:
+            raise ValueError(f"Cannot evenly split {self.num_shards} shards among {self.world_size} devices")
         shard_offset = self.rank // devices_per_shard
         device_offset = self.rank % devices_per_shard
 
+        # Select a deterministic subset of shards
         device_filepaths = filepaths[shard_offset::self.world_size]
         dataset._ex_iterable.kwargs['filepaths'] = device_filepaths
 
-        # Subsample dataset if shard is being shared among devices
+        # Subsample shard if shard is being shared among devices
         # NOTE: Mapping is executed in batched mode for better CPU utilization,
         # but the returned dataset is still an iterable over text samples
         if devices_per_shard > 1:
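For reference, a minimal standalone sketch of the rank-to-shard arithmetic this hunk introduces. `assign_shard` and its bare `rank`/`world_size`/`num_shards` parameters are hypothetical stand-ins for the instance attributes used in `_shard_dataset`; this is not part of the diff.

# Minimal sketch (assumed names, not part of the diff) of the rank-to-shard
# arithmetic above: each device computes which shard it reads (shard_offset)
# and its position within that shard's device group (device_offset).
def assign_shard(rank: int, world_size: int, num_shards: int):
    devices_per_shard = 1
    if num_shards < world_size:
        # Several devices must share one shard; require an even split.
        if world_size % num_shards != 0:
            raise ValueError(f"Cannot evenly split {num_shards} shards among {world_size} devices")
        devices_per_shard = world_size // num_shards
    elif num_shards % world_size != 0:
        raise ValueError(f"Cannot evenly split {num_shards} shards among {world_size} devices")
    shard_offset = rank // devices_per_shard
    device_offset = rank % devices_per_shard
    return shard_offset, device_offset

# 2 shards, 4 devices: ranks 0 and 1 share shard 0, ranks 2 and 3 share shard 1.
for rank in range(4):
    print(rank, assign_shard(rank, world_size=4, num_shards=2))
# 0 (0, 0)
# 1 (0, 1)
# 2 (1, 0)
# 3 (1, 1)

Note that the slice `filepaths[shard_offset::self.world_size]` covers both regimes: when `devices_per_shard > 1`, `shard_offset < num_shards < world_size`, so the stride selects exactly one filepath per shard group; otherwise it deals out `num_shards // world_size` filepaths per rank.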