Skip to content

Commit cb45757

Browse files
committed
minor edits
1 parent ed7cbf0 commit cb45757

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

composer/datasets/c4.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -256,26 +256,30 @@ def _subsample(self, device_offset, text_batch):
256256
return text_batch
257257

258258
def _shard_dataset(self, dataset):
259-
# Select a deterministic subset of filepaths for sharded data-parallel training
259+
# Verify # of shards
260260
filepaths = dataset._ex_iterable.kwargs['filepaths']
261261
if self.num_shards != len(filepaths):
262262
raise ValueError(f"Found {len(filepaths)} shards, expected {self.num_shards}")
263263

264+
# Determine how to allocate devices to shards
264265
devices_per_shard = 1
265-
if self.world_size > self.num_shards:
266+
if self.num_shards < self.world_size:
266267
log.warning(
267268
f"Not enough unique shards ({self.num_shards}) for world size ({self.world_size}). Splitting shards among devices."
268269
)
269270
if self.world_size % self.num_shards != 0:
270271
raise ValueError(f"Cannot evenly split {self.num_shards} shards among {self.world_size} devices")
271272
devices_per_shard = self.world_size // self.num_shards
273+
elif self.num_shards % self.world_size != 0:
274+
raise ValueError(f"Cannot evenly split {self.num_shards} shards among {self.world_size} devices")
272275
shard_offset = self.rank // devices_per_shard
273276
device_offset = self.rank % devices_per_shard
274277

278+
# Select a deterministic subset of shards
275279
device_filepaths = filepaths[shard_offset::self.world_size]
276280
dataset._ex_iterable.kwargs['filepaths'] = device_filepaths
277281

278-
# Subsample dataset if shard is being shared among devices
282+
# Subsample shard if shard is being shared among devices
279283
# NOTE: Mapping is executed in batched mode for better CPU utilization,
280284
# but the returned dataset is still an iterable over text samples
281285
if devices_per_shard > 1:

0 commit comments

Comments
 (0)