
Commit 1566ce2

Distinguish between dist and DDP (#201)
* dist, not ddp
* simplify ClosureGradScaler
* formatting
* formatting and more fixes
* that did not save
* small fixes
* dont need to worry about circular dependencies any longer
* dumb pyright fix
* woops
1 parent 995fafe commit 1566ce2

28 files changed (+251 -242 lines)

composer/algorithms/seq_length_warmup/seq_length_warmup.py (3 additions, 3 deletions)

@@ -10,7 +10,7 @@
 from composer.algorithms import AlgorithmHparams
 from composer.core.types import Algorithm, Batch, Event, Logger, State, Tensor
 from composer.models.transformer_shared import MosaicTransformer
-from composer.utils import ddp, ensure_tuple
+from composer.utils import dist, ensure_tuple
 
 
 def apply_seq_length_warmup(batch: Dict[str, Tensor], curr_seq_len: int, truncate: bool) -> Batch:
@@ -180,8 +180,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
         # all of the parameters
         device = next(state.model.parameters()).device
 
-        assert (state.train_batch_size % ddp.get_world_size()) == 0
-        per_gpu_batch = math.ceil(state.train_batch_size / (ddp.get_world_size() * state.grad_accum))
+        assert (state.train_batch_size % dist.get_world_size()) == 0
+        per_gpu_batch = math.ceil(state.train_batch_size / (dist.get_world_size() * state.grad_accum))
         input_ids = torch.randint(low=0,
                                   high=vocab_size - 1,
                                   size=(per_gpu_batch, self.hparams.max_seq_length),
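
The per-GPU batch computed above divides the global train batch size by the world size and the number of gradient-accumulation steps. A quick arithmetic sketch with hypothetical numbers (not taken from the commit):

import math

# Hypothetical values: a global batch of 2048 samples, 8 ranks, 4 grad-accum steps.
train_batch_size = 2048
world_size = 8      # stand-in for dist.get_world_size()
grad_accum = 4

assert train_batch_size % world_size == 0   # mirrors the assert in the hunk
per_gpu_batch = math.ceil(train_batch_size / (world_size * grad_accum))
print(per_gpu_batch)                        # 64 samples per rank per microbatch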

composer/callbacks/benchmarker.py (2 additions, 2 deletions)

@@ -11,7 +11,7 @@
 from composer.core import Logger, State
 from composer.core.callback import Callback
 from composer.core.types import BreakEpochException
-from composer.utils import ddp
+from composer.utils import dist
 
 log = logging.getLogger(__name__)
 
@@ -159,7 +159,7 @@ def batch_end(self, state: State, logger: Logger):
         now = time.time()
         elapsed = now - self.current_time
         self.current_time = now
-        self.profile_examples += state.last_batch_size * ddp.get_world_size()
+        self.profile_examples += state.last_batch_size * dist.get_world_size()
         self.profile_steps += 1
         self.profile_time += elapsed
 
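
state.last_batch_size is a per-rank count, so multiplying by the world size turns it into a global sample count before throughput is derived. A minimal sketch of that bookkeeping, with hypothetical numbers:

import time

world_size = 8        # stand-in for dist.get_world_size()
last_batch_size = 64  # per-rank samples in the most recent batch

profile_examples = 0
profile_time = 0.0
current_time = time.time()
# ... one training batch would run here ...
now = time.time()
profile_examples += last_batch_size * world_size  # 512 global samples
profile_time += now - current_time
# global throughput would then be profile_examples / profile_time (examples/sec)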

composer/callbacks/run_directory_uploader.py (7 additions, 7 deletions)

@@ -20,7 +20,7 @@
 from composer.core.logging import Logger
 from composer.core.logging.logger import LogLevel
 from composer.core.state import State
-from composer.utils import ddp
+from composer.utils import dist
 from composer.utils.run_directory import get_modified_files, get_run_directory
 
 log = logging.getLogger(__name__)
@@ -148,13 +148,13 @@ def __init__(
         self._finished: Union[None, multiprocessing._EventType, threading.Event] = None
         self._workers = []
 
-        if ddp.get_local_rank() == 0:
+        if dist.get_local_rank() == 0:
             _validate_credentials(provider, container, self._object_name_prefix, provider_init_kwargs)
 
     def init(self, state: State, logger: Logger) -> None:
         if get_run_directory() is None:
             return
-        if not ddp.get_local_rank() == 0:
+        if not dist.get_local_rank() == 0:
             return
         del state, logger # unused
         self._finished = self._finished_cls()
@@ -194,7 +194,7 @@ def training_end(self, state: State, logger: Logger) -> None:
     def post_close(self):
         # Cleaning up on post_close to ensure that all artifacts are uploaded
         self._trigger_upload(logger=None, log_level=None)
-        if not ddp.get_local_rank() == 0:
+        if not dist.get_local_rank() == 0:
             return
         if self._finished is not None:
             self._finished.set()
@@ -207,8 +207,8 @@ def _trigger_upload(self, logger: Optional[Logger], log_level: Optional[LogLevel
         # Ensure that every rank is at this point
         # Assuming only the main thread on each rank writes to the run directory, then the barrier here will ensure
         # that the run directory is not being modified after we pass this barrier
-        ddp.barrier()
-        if ddp.get_local_rank() == 0:
+        dist.barrier()
+        if dist.get_local_rank() == 0:
             run_directory = get_run_directory()
             assert run_directory is not None, "invariant error"
             # the disk time can differ from system time, so going to touch a file and then read the timestamp from it to get the real time
@@ -244,7 +244,7 @@ def _trigger_upload(self, logger: Optional[Logger], log_level: Optional[LogLevel
             logger.metric(log_level, {"run_directory/uploaded_files": files_to_be_uploaded})
 
         # Ensure that other callbacks do not start writing to the run directory until we copying everything
-        ddp.barrier()
+        dist.barrier()
 
 
 def _validate_credentials(
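
The hunks above all follow the same discipline: synchronize every rank with a barrier, do the filesystem work only on the local-rank-zero process, then synchronize again before anything else writes to the run directory. A minimal sketch of that pattern, assuming the process group has already been initialized by the trainer; upload_files is a hypothetical stand-in for the uploader's real work:

from composer.utils import dist

def upload_files():
    # Hypothetical stand-in for enqueueing the run directory's files for upload.
    print("uploading run directory")

def trigger_upload():
    # Wait until every rank has finished writing to the run directory.
    dist.barrier()
    if dist.get_local_rank() == 0:
        upload_files()
    # Keep the other ranks from writing again until the copy step is done.
    dist.barrier()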

composer/callbacks/speed_monitor.py (2 additions, 2 deletions)

@@ -10,7 +10,7 @@
 from composer.core import Logger, State
 from composer.core.callback import RankZeroCallback
 from composer.core.types import StateDict
-from composer.utils import ddp
+from composer.utils import dist
 
 
 class SpeedMonitor(RankZeroCallback):
@@ -84,7 +84,7 @@ def batch_end(self, state: State, logger: Logger):
         # Ideally, callbacks would have a way of reducing tensors.
         # It assumes that each process has equal batch sizing
         # For the speed monitor, we might be able to use the static step converter with num_samples
-        batch_num_samples *= ddp.get_world_size()
+        batch_num_samples *= dist.get_world_size()
         self.batch_num_samples.append(batch_num_samples)
         self.train_examples_per_epoch += batch_num_samples
         if len(self.batch_end_times) == self.hparams.window_size + 1:

composer/callbacks/torch_profiler.py (1 addition, 1 deletion)

@@ -12,7 +12,7 @@
 from composer.callbacks.callback_hparams import TorchProfilerHparams
 from composer.core import Callback, Logger, State
 from composer.core.types import StateDict
-from composer.utils.ddp import get_global_rank
+from composer.utils.dist import get_global_rank
 from composer.utils.run_directory import get_relative_to_run_directory
 
 _PROFILE_MISSING_ERROR = "The profiler has not been setup. Please call profiler.training_start() before training starts."
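
The profiler imports get_global_rank, while the callbacks above key off dist.get_local_rank(); the two differ on multi-node jobs. A small sketch, assuming a hypothetical 2-node, 4-GPU-per-node job with an initialized process group:

from composer.utils import dist
from composer.utils.dist import get_global_rank

# With 2 nodes of 4 GPUs each: global rank runs 0-7 across the whole job,
# local rank runs 0-3 within each node, and the world size is 8.
print(get_global_rank(), dist.get_local_rank(), dist.get_world_size())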

composer/core/callback.py (2 additions, 2 deletions)

@@ -8,7 +8,7 @@
 from typing import TYPE_CHECKING
 
 from composer.core.serializable import Serializable
-from composer.utils import ddp
+from composer.utils import dist
 
 try:
     from typing import final
@@ -333,6 +333,6 @@ class RankZeroCallback(Callback, abc.ABC):
 
     @final
     def run_event(self, event: Event, state: State, logger: Logger) -> None:
-        if ddp.get_local_rank() != 0:
+        if dist.get_local_rank() != 0:
             return
         return self._run_event(event, state, logger)
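
Because run_event returns early on every process whose local rank is non-zero, any subclass of RankZeroCallback only does work once per node. A hypothetical sketch of a callback that relies on that guard (the class is illustrative, not part of the commit):

from composer.core import Logger, State
from composer.core.callback import RankZeroCallback

class LogLastBatchSize(RankZeroCallback):
    # batch_end only fires on the local-rank-zero process of each node,
    # thanks to the run_event guard shown above.

    def batch_end(self, state: State, logger: Logger) -> None:
        print(f"last per-rank batch size: {state.last_batch_size}")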

composer/core/logging/base_backend.py (3 additions, 3 deletions)

@@ -6,7 +6,7 @@
 from typing import TYPE_CHECKING
 
 from composer.core.callback import Callback, RankZeroCallback
-from composer.utils import ddp
+from composer.utils import dist
 
 if TYPE_CHECKING:
     from composer.core.logging.logger import LogLevel, TLogData
@@ -104,7 +104,7 @@ def _will_log(self, state: State, log_level: LogLevel) -> bool:
 
     @final
     def will_log(self, state: State, log_level: LogLevel) -> bool:
-        if ddp.get_local_rank() != 0:
+        if dist.get_local_rank() != 0:
             return False
         return self._will_log(state, log_level)
 
@@ -126,6 +126,6 @@ def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData
 
     @final
     def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None:
-        if ddp.get_local_rank() != 0:
+        if dist.get_local_rank() != 0:
             return
         return self._log_metric(epoch, step, log_level, data)

composer/core/state.py (3 additions, 3 deletions)

@@ -13,7 +13,7 @@
 import composer.core.types as types
 from composer.core.precision import Precision
 from composer.core.serializable import Serializable
-from composer.utils import ddp, ensure_tuple
+from composer.utils import dist, ensure_tuple
 from composer.utils.precision import default_precision_factory
 
 if TYPE_CHECKING:
@@ -185,14 +185,14 @@ def train_batch_size(self):
         """The global batch size used for training."""
         if self.train_dataloader.batch_size is None:
             raise RuntimeError("train dataloader batch size is undefined")
-        return self.train_dataloader.batch_size * ddp.get_world_size()
+        return self.train_dataloader.batch_size * dist.get_world_size()
 
     @property
     def eval_batch_size(self):
         """The batch size used for evaluation."""
         if self.eval_dataloader.batch_size is None:
             raise RuntimeError("eval dataloader batch size is undefined")
-        return self.eval_dataloader.batch_size * ddp.get_world_size()
+        return self.eval_dataloader.batch_size * dist.get_world_size()
 
     def state_dict(self) -> types.StateDict:
         """Returns the state as a :class:`dict`."""

composer/datasets/brats.py (2 additions, 2 deletions)

@@ -14,7 +14,7 @@
 from composer.core.types import DataLoader, Dataset
 from composer.datasets.dataloader import DataloaderHparams
 from composer.datasets.hparams import DatasetHparams
-from composer.utils import ddp
+from composer.utils import dist
 
 PATCH_SIZE = [1, 192, 160]
 
@@ -48,7 +48,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
         x_train, y_train, x_val, y_val = get_data_split(self.datadir)
         dataset = PytTrain(x_train, y_train, oversampling) if self.is_train else PytVal(x_val, y_val)
         collate_fn = None if self.is_train else _my_collate
-        sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
+        sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
 
         return dataloader_hparams.initialize_object(
             dataset=dataset,
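
dist.get_sampler returns a distributed sampler that partitions the dataset across ranks, so each rank's dataloader iterates over a disjoint shard. A minimal sketch outside of the hparams machinery, using a toy dataset and assuming the process group is already initialized:

import torch
from torch.utils.data import DataLoader, TensorDataset

from composer.utils import dist

# Toy dataset standing in for the BraTS / CIFAR-10 datasets in this commit.
dataset = TensorDataset(torch.randn(1024, 3, 32, 32), torch.randint(0, 10, (1024,)))

sampler = dist.get_sampler(dataset, drop_last=False, shuffle=True)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)  # per-rank batches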

composer/datasets/cifar10.py (2 additions, 2 deletions)

@@ -10,7 +10,7 @@
 from composer.datasets.dataloader import DataloaderHparams
 from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin
 from composer.datasets.synthetic import SyntheticBatchPairDataset
-from composer.utils import ddp
+from composer.utils import dist
 
 
 @dataclass
@@ -59,7 +59,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
             download=self.download,
             transform=transformation,
         )
-        sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
+        sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
 
         return dataloader_hparams.initialize_object(dataset,
                                                     batch_size=batch_size,
