 
 from __future__ import annotations
 
+import warnings
 from dataclasses import dataclass
-from typing import Any, Iterator
+from typing import Any, Iterator, Optional
 
 import torch
 import torch.distributed
 import torch.utils.data
 import yahp as hp
+from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import Sampler
 
 from composer.core.types import Batch, DataLoader
@@ -44,6 +46,44 @@ def __setattr__(self, name: str, value: Any) -> None:
         return super().__setattr__(name, value)
 
 
+class DDPDataLoader(WrappedDataLoader):
+    """Ensure ``sampler.set_epoch()`` is called after each iteration.
+
+    DDPDataLoader wraps a dataloader whose sampler is a DistributedSampler and
+    calls ``set_epoch()`` on that sampler after each iteration (epoch) through the dataset.
+    See: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
+    """
+
+    def __init__(self, dataloader: DataLoader) -> None:
+        super().__init__(dataloader)
+        if not isinstance(self.dataloader.sampler, DistributedSampler):
+            raise ValueError("When using the DDP data loader, the sampler must be a DistributedSampler")
+        self._iterator: Optional[Iterator[Batch]] = None
+
+    def __iter__(self) -> DDPDataLoader:
+        if self._iterator is not None:
+            warnings.warn(
+                "DataloaderMultipleIterationWarning: "
+                "The dataloader detected the start of a new iteration before the previous iteration finished. "
+                "The dataloader is skipping ahead to the start of the next epoch. "
+                "Multiple simultaneous iterations through the DDP dataloader are prohibited, since "
+                "it automatically tracks the current epoch.")
+            assert isinstance(self.sampler, DistributedSampler)
+            self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
+        self._iterator = iter(self.dataloader)
+        return self
+
+    def __next__(self) -> Batch:
+        assert self._iterator is not None
+        try:
+            return next(self._iterator)
+        except StopIteration:
+            self._iterator = None
+            assert isinstance(self.sampler, DistributedSampler)
+            self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
+            raise
+
+
 @dataclass
 class DataloaderHparams(hp.Hparams):
     """Hyperparameters to initialize a ``torch.utils.data.Dataloader``."""