
Commit 3ac8fbc

ravi-mosaicml authored and coryMosaicML committed
Remove composer.trainer.ddp; replace with composer.utils.ddp (mosaicml#105)
Before mosaicml#65, composer.trainer.ddp ensured that DDP functionality was accessed only after DDP was initialized. Now that DDP is available from process start, this class is no longer needed, so all of its functionality has been moved to the global composer.utils.ddp module. This change allows callbacks, algorithms, etc. to use DDP operations (such as barriers and reductions) as needed; mosaicml#97 and mosaicml#101 depend on this functionality. DDP has also been removed from the state, since it is available globally.
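
Below is a minimal sketch of the pattern this change enables: an algorithm or callback imports the global ddp module directly instead of reading DDP information off State. Only get_world_size() and get_global_rank(), which appear in the diffs below, are used; the barrier and reduction helpers mentioned above are not shown in this diff, so they are omitted here.

# Sketch only: a helper that any algorithm or callback could call, using the
# global composer.utils.ddp module instead of the removed State properties.
from composer.utils import ddp

def log_effective_batch_size(per_gpu_batch_size: int) -> None:
    # The global batch size is the per-GPU batch size times the DDP world size.
    world_size = ddp.get_world_size()
    if ddp.get_global_rank() == 0:
        print(f"effective batch size: {per_gpu_batch_size * world_size}")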
1 parent 100a0d9 commit 3ac8fbc

File tree

22 files changed: +366, -384 lines

composer/algorithms/curriculum_learning/curriculum_learning.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 from composer.algorithms import AlgorithmHparams
 from composer.core.types import Algorithm, Batch, Event, Logger, State, Tensor
 from composer.models.transformer_shared import MosaicTransformer
-from composer.utils import ensure_tuple
+from composer.utils import ddp, ensure_tuple
 
 
 def apply_curriculum(batch: Dict[str, Tensor], curr_seq_len: int, truncate: bool) -> Batch:
@@ -153,8 +153,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
         # all of the parameters
         device = next(state.model.parameters()).device
 
-        assert (state.train_batch_size % state.world_size) == 0
-        per_gpu_batch = math.ceil(state.train_batch_size / (state.world_size * state.grad_accum))
+        assert (state.train_batch_size % ddp.get_world_size()) == 0
+        per_gpu_batch = math.ceil(state.train_batch_size / (ddp.get_world_size() * state.grad_accum))
         input_ids = torch.randint(low=0,
                                   high=vocab_size - 1,
                                   size=(per_gpu_batch, self.hparams.max_seq_length),

composer/algorithms/seq_length_warmup/seq_length_warmup.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 from composer.algorithms import AlgorithmHparams
 from composer.core.types import Algorithm, Batch, Event, Logger, State, Tensor
 from composer.models.transformer_shared import MosaicTransformer
-from composer.utils import ensure_tuple
+from composer.utils import ddp, ensure_tuple
 
 
 def apply_seq_length_warmup(batch: Dict[str, Tensor], curr_seq_len: int, truncate: bool) -> Batch:
@@ -180,8 +180,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
         # all of the parameters
         device = next(state.model.parameters()).device
 
-        assert (state.train_batch_size % state.world_size) == 0
-        per_gpu_batch = math.ceil(state.train_batch_size / (state.world_size * state.grad_accum))
+        assert (state.train_batch_size % ddp.get_world_size()) == 0
+        per_gpu_batch = math.ceil(state.train_batch_size / (ddp.get_world_size() * state.grad_accum))
         input_ids = torch.randint(low=0,
                                   high=vocab_size - 1,
                                   size=(per_gpu_batch, self.hparams.max_seq_length),

composer/callbacks/benchmarker.py

Lines changed: 2 additions & 1 deletion
@@ -11,6 +11,7 @@
 from composer.callbacks.callback_hparams import BenchmarkerHparams
 from composer.core.callback import Callback
 from composer.core.types import BreakEpochException
+from composer.utils import ddp
 
 log = logging.getLogger(__name__)
 
@@ -158,7 +159,7 @@ def batch_end(self, state: State, logger: Logger):
         now = time.time()
         elapsed = now - self.current_time
         self.current_time = now
-        self.profile_examples += state.last_batch_size * state.world_size
+        self.profile_examples += state.last_batch_size * ddp.get_world_size()
         self.profile_steps += 1
         self.profile_time += elapsed

composer/callbacks/speed_monitor.py

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@
 from composer.callbacks.callback_hparams import SpeedMonitorHparams
 from composer.core.callback import RankZeroCallback
 from composer.core.types import StateDict
+from composer.utils import ddp
 
 
 class SpeedMonitor(RankZeroCallback):
@@ -83,7 +84,7 @@ def batch_end(self, state: State, logger: Logger):
         # Ideally, callbacks would have a way of reducing tensors.
         # It assumes that each process has equal batch sizing
         # For the speed monitor, we might be able to use the static step converter with num_samples
-        batch_num_samples *= state.world_size
+        batch_num_samples *= ddp.get_world_size()
         self.batch_num_samples.append(batch_num_samples)
         self.train_examples_per_epoch += batch_num_samples
         if len(self.batch_end_times) == self.hparams.window_size + 1:

composer/core/callback.py

Lines changed: 8 additions & 3 deletions
@@ -8,7 +8,12 @@
 from typing import TYPE_CHECKING
 
 from composer.core.serializable import Serializable
-from composer.utils.ddp import is_rank_zero
+from composer.utils import ddp
+
+try:
+    from typing import final
+except ImportError:
+    final = lambda x: x  # final is not available in python 3.7
 
 try:
     from typing import final
@@ -299,7 +304,7 @@ def eval_end(self, state: State, logger: Logger) -> None:
 
 
 class RankZeroCallback(Callback, abc.ABC):
-    """Base class for callbacks that only run on the rank zero process.
+    """Base class for callbacks that only run on the local rank zero process.
 
     Callbacks can be implemented in two ways:
 
@@ -314,6 +319,6 @@ class RankZeroCallback(Callback, abc.ABC):
 
     @final
     def run_event(self, event: Event, state: State, logger: Logger) -> None:
-        if not is_rank_zero():
+        if ddp.get_local_rank() != 0:
             return
         return self._run_event(event, state, logger)
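
As a usage sketch of the updated API (assuming only the RankZeroCallback hooks shown in this diff, with Event/State/Logger imported from composer.core.types as in the other files above), a subclass overrides _run_event, and the base class's run_event skips every process whose local rank is non-zero:

# Sketch: a callback that only runs on the local rank zero process.
from composer.core.callback import RankZeroCallback
from composer.core.types import Event, Logger, State


class RankZeroPrinter(RankZeroCallback):

    def _run_event(self, event: Event, state: State, logger: Logger) -> None:
        # run_event() in the base class returns early when ddp.get_local_rank() != 0,
        # so this body only executes on local rank zero.
        print(f"[local rank 0] saw event {event}")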

composer/core/logging/base_backend.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 from typing import TYPE_CHECKING
 
 from composer.core.callback import Callback, RankZeroCallback
-from composer.utils.ddp import is_rank_zero
+from composer.utils import ddp
 
 if TYPE_CHECKING:
     from composer.core.logging.logger import LogLevel, TLogData
@@ -104,7 +104,7 @@ def _will_log(self, state: State, log_level: LogLevel) -> bool:
 
     @final
     def will_log(self, state: State, log_level: LogLevel) -> bool:
-        if not state.is_rank_zero:
+        if ddp.get_local_rank() != 0:
             return False
         return self._will_log(state, log_level)
 
@@ -126,6 +126,6 @@ def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData
 
     @final
    def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None:
-        if not is_rank_zero():
+        if ddp.get_local_rank() != 0:
             return
         return self._log_metric(epoch, step, log_level, data)

composer/core/state.py

Lines changed: 0 additions & 21 deletions
@@ -16,7 +16,6 @@
 from composer.core.precision import Precision
 from composer.core.serializable import Serializable
 from composer.utils import ensure_tuple
-from composer.utils.ddp import get_global_rank, get_local_rank, get_local_world_size, get_world_size
 from composer.utils.precision import default_precision_factory
 
 if TYPE_CHECKING:
@@ -142,26 +141,6 @@ class State(Serializable):
     algorithms: Sequence[Algorithm] = tuple()
     callbacks: Sequence[Callback] = tuple()
 
-    @property
-    def world_size(self) -> int:
-        return get_world_size()
-
-    @property
-    def global_rank(self) -> int:
-        return get_global_rank()
-
-    @property
-    def local_world_size(self) -> int:
-        return get_local_world_size()
-
-    @property
-    def local_rank(self) -> int:
-        return get_local_rank()
-
-    @property
-    def is_rank_zero(self) -> bool:
-        return self.global_rank == 0
-
     def state_dict(self) -> types.StateDict:
         """Returns the state as a :class:`dict`."""
         state_dict: types.StateDict = {}
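
For callers of the removed properties, the replacements are the same functions State was already delegating to; a short migration sketch (the mapping is read directly off the deleted property bodies above):

# Migration sketch for code that used the removed State properties.
from composer.utils import ddp

# state.world_size        -> ddp.get_world_size()
# state.global_rank       -> ddp.get_global_rank()
# state.local_world_size  -> ddp.get_local_world_size()
# state.local_rank        -> ddp.get_local_rank()
# state.is_rank_zero      -> ddp.get_global_rank() == 0
is_rank_zero = ddp.get_global_rank() == 0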

composer/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from composer.datasets.brats import BratsDatasetHparams as BratsDatasetHparams
 from composer.datasets.cifar10 import CIFAR10DatasetHparams as CIFAR10DatasetHparams
 from composer.datasets.dataloader import DataloaderHparams as DataloaderHparams
+from composer.datasets.dataloader import DDPDataLoader as DDPDataLoader
 from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader
 from composer.datasets.hparams import DataloaderSpec as DataloaderSpec
 from composer.datasets.hparams import DatasetHparams as DatasetHparams

composer/datasets/dataloader.py

Lines changed: 41 additions & 1 deletion
@@ -2,13 +2,15 @@
 
 from __future__ import annotations
 
+import warnings
 from dataclasses import dataclass
-from typing import Any, Iterator
+from typing import Any, Iterator, Optional
 
 import torch
 import torch.distributed
 import torch.utils.data
 import yahp as hp
+from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import Sampler
 
 from composer.core.types import Batch, DataLoader
@@ -44,6 +46,44 @@ def __setattr__(self, name: str, value: Any) -> None:
         return super().__setattr__(name, value)
 
 
+class DDPDataLoader(WrappedDataLoader):
+    """Ensure sampler.set_epoch() is called after each iteration.
+
+    DDPDataLoader wraps a dataloader and a distributed sampler and is
+    called after each iteration (epoch) through the dataset.
+    See: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
+    """
+
+    def __init__(self, dataloader: DataLoader) -> None:
+        super().__init__(dataloader)
+        if not isinstance(self.dataloader.sampler, DistributedSampler):
+            raise ValueError("When using the DDP data loader, the sampler must be a DistributedSampler")
+        self._iterator: Optional[Iterator[Batch]] = None
+
+    def __iter__(self) -> DDPDataLoader:
+        if self._iterator is not None:
+            warnings.warn(
+                "DataloaderMultipleIterationWarning: "
+                "The dataloader detected the start of a new iteration before the previous iteration finished. "
+                "The dataloader is skipping ahead to the start of the next epoch. "
+                "Multiple simultaneous iterations through the DDP dataloader prohibited, since "
+                "it automatically tracks the current epoch.")
+            assert isinstance(self.sampler, DistributedSampler)
+            self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
+        self._iterator = iter(self.dataloader)
+        return self
+
+    def __next__(self) -> Batch:
+        assert self._iterator is not None
+        try:
+            return next(self._iterator)
+        except StopIteration:
+            self._iterator = None
+            assert isinstance(self.sampler, DistributedSampler)
+            self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
+            raise
+
+
 @dataclass
 class DataloaderHparams(hp.Hparams):
     """Hyperparameters to initialize a ``torch.utils.data.Dataloader``."""

composer/trainer/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 # Copyright 2021 MosaicML. All Rights Reserved.
 
 from composer.trainer import devices as devices
-from composer.trainer.ddp import DDPDataLoader as DDPDataLoader
 from composer.trainer.trainer import Trainer as Trainer
 from composer.trainer.trainer_hparams import TrainerHparams as TrainerHparams
