6 changes: 3 additions & 3 deletions composer/algorithms/seq_length_warmup/seq_length_warmup.py
@@ -10,7 +10,7 @@
from composer.algorithms import AlgorithmHparams
from composer.core.types import Algorithm, Batch, Event, Logger, State, Tensor
from composer.models.transformer_shared import MosaicTransformer
from composer.utils import ddp, ensure_tuple
from composer.utils import dist, ensure_tuple


def apply_seq_length_warmup(batch: Dict[str, Tensor], curr_seq_len: int, truncate: bool) -> Batch:
@@ -180,8 +180,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
# all of the parameters
device = next(state.model.parameters()).device

assert (state.train_batch_size % ddp.get_world_size()) == 0
per_gpu_batch = math.ceil(state.train_batch_size / (ddp.get_world_size() * state.grad_accum))
assert (state.train_batch_size % dist.get_world_size()) == 0
per_gpu_batch = math.ceil(state.train_batch_size / (dist.get_world_size() * state.grad_accum))
input_ids = torch.randint(low=0,
high=vocab_size - 1,
size=(per_gpu_batch, self.hparams.max_seq_length),
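The hunk above computes the per-GPU microbatch used to build the dummy warmup batch. A minimal standalone sketch of that arithmetic, with made-up values in place of `state` and `self.hparams` (the variable names and numbers here are illustrative only):

```python
import math

import torch

# Illustrative stand-ins for values the algorithm reads from `state` and `self.hparams`.
world_size = 8           # dist.get_world_size()
global_batch_size = 512  # state.train_batch_size
grad_accum = 2           # state.grad_accum
max_seq_length = 128     # self.hparams.max_seq_length
vocab_size = 30522

# The global batch must split evenly across ranks.
assert global_batch_size % world_size == 0

# Samples each rank sees per forward pass, after gradient accumulation is applied.
per_gpu_batch = math.ceil(global_batch_size / (world_size * grad_accum))  # 32

# Dummy token ids shaped like one warmup microbatch.
input_ids = torch.randint(low=0, high=vocab_size - 1, size=(per_gpu_batch, max_seq_length))
print(input_ids.shape)  # torch.Size([32, 128])
```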
4 changes: 2 additions & 2 deletions composer/callbacks/benchmarker.py
@@ -11,7 +11,7 @@
from composer.core import Logger, State
from composer.core.callback import Callback
from composer.core.types import BreakEpochException
from composer.utils import ddp
from composer.utils import dist

log = logging.getLogger(__name__)

@@ -159,7 +159,7 @@ def batch_end(self, state: State, logger: Logger):
now = time.time()
elapsed = now - self.current_time
self.current_time = now
self.profile_examples += state.last_batch_size * ddp.get_world_size()
self.profile_examples += state.last_batch_size * dist.get_world_size()
self.profile_steps += 1
self.profile_time += elapsed

14 changes: 7 additions & 7 deletions composer/callbacks/run_directory_uploader.py
@@ -20,7 +20,7 @@
from composer.core.logging import Logger
from composer.core.logging.logger import LogLevel
from composer.core.state import State
from composer.utils import ddp
from composer.utils import dist
from composer.utils.run_directory import get_modified_files, get_run_directory

log = logging.getLogger(__name__)
@@ -148,13 +148,13 @@ def __init__(
self._finished: Union[None, multiprocessing._EventType, threading.Event] = None
self._workers = []

if ddp.get_local_rank() == 0:
if dist.get_local_rank() == 0:
_validate_credentials(provider, container, self._object_name_prefix, provider_init_kwargs)

def init(self, state: State, logger: Logger) -> None:
if get_run_directory() is None:
return
if not ddp.get_local_rank() == 0:
if not dist.get_local_rank() == 0:
return
del state, logger # unused
self._finished = self._finished_cls()
@@ -194,7 +194,7 @@ def training_end(self, state: State, logger: Logger) -> None:
def post_close(self):
# Cleaning up on post_close to ensure that all artifacts are uploaded
self._trigger_upload(logger=None, log_level=None)
if not ddp.get_local_rank() == 0:
if not dist.get_local_rank() == 0:
return
if self._finished is not None:
self._finished.set()
@@ -207,8 +207,8 @@ def _trigger_upload(self, logger: Optional[Logger], log_level: Optional[LogLevel
# Ensure that every rank is at this point
# Assuming only the main thread on each rank writes to the run directory, then the barrier here will ensure
# that the run directory is not being modified after we pass this barrier
ddp.barrier()
if ddp.get_local_rank() == 0:
dist.barrier()
if dist.get_local_rank() == 0:
run_directory = get_run_directory()
assert run_directory is not None, "invariant error"
# the disk time can differ from system time, so we touch a file and then read the timestamp from it to get the real time
@@ -244,7 +244,7 @@ def _trigger_upload(self, logger: Optional[Logger], log_level: Optional[LogLevel
logger.metric(log_level, {"run_directory/uploaded_files": files_to_be_uploaded})

# Ensure that other callbacks do not start writing to the run directory until we finish copying everything
ddp.barrier()
dist.barrier()


def _validate_credentials(
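The `_trigger_upload` hunks above illustrate the synchronization pattern the renamed `dist` module is used for: every rank hits a barrier, only local rank 0 touches the run directory, and a second barrier keeps later callbacks from writing until the copy is done. A minimal sketch of that pattern (the `copy_run_directory` helper is hypothetical):

```python
from composer.utils import dist


def copy_run_directory() -> None:
    """Hypothetical stand-in for the uploader's real copy-and-enqueue logic."""


def trigger_upload() -> None:
    # Wait until every rank has stopped writing to the run directory.
    dist.barrier()
    if dist.get_local_rank() == 0:
        # Only one process per node reads the shared run directory.
        copy_run_directory()
    # Non-zero ranks wait here, so nothing writes to the run directory
    # while local rank 0 is still copying files.
    dist.barrier()
```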
4 changes: 2 additions & 2 deletions composer/callbacks/speed_monitor.py
@@ -10,7 +10,7 @@
from composer.core import Logger, State
from composer.core.callback import RankZeroCallback
from composer.core.types import StateDict
from composer.utils import ddp
from composer.utils import dist


class SpeedMonitor(RankZeroCallback):
@@ -84,7 +84,7 @@ def batch_end(self, state: State, logger: Logger):
# Ideally, callbacks would have a way of reducing tensors.
# It assumes that each process has equal batch sizing
# For the speed monitor, we might be able to use the static step converter with num_samples
batch_num_samples *= ddp.get_world_size()
batch_num_samples *= dist.get_world_size()
self.batch_num_samples.append(batch_num_samples)
self.train_examples_per_epoch += batch_num_samples
if len(self.batch_end_times) == self.hparams.window_size + 1:
2 changes: 1 addition & 1 deletion composer/callbacks/torch_profiler.py
@@ -12,7 +12,7 @@
from composer.callbacks.callback_hparams import TorchProfilerHparams
from composer.core import Callback, Logger, State
from composer.core.types import StateDict
from composer.utils.ddp import get_global_rank
from composer.utils.dist import get_global_rank
from composer.utils.run_directory import get_relative_to_run_directory

_PROFILE_MISSING_ERROR = "The profiler has not been setup. Please call profiler.training_start() before training starts."
4 changes: 2 additions & 2 deletions composer/core/callback.py
@@ -8,7 +8,7 @@
from typing import TYPE_CHECKING

from composer.core.serializable import Serializable
from composer.utils import ddp
from composer.utils import dist

try:
from typing import final
@@ -333,6 +333,6 @@ class RankZeroCallback(Callback, abc.ABC):

@final
def run_event(self, event: Event, state: State, logger: Logger) -> None:
if ddp.get_local_rank() != 0:
if dist.get_local_rank() != 0:
return
return self._run_event(event, state, logger)
6 changes: 3 additions & 3 deletions composer/core/logging/base_backend.py
@@ -6,7 +6,7 @@
from typing import TYPE_CHECKING

from composer.core.callback import Callback, RankZeroCallback
from composer.utils import ddp
from composer.utils import dist

if TYPE_CHECKING:
from composer.core.logging.logger import LogLevel, TLogData
@@ -104,7 +104,7 @@ def _will_log(self, state: State, log_level: LogLevel) -> bool:

@final
def will_log(self, state: State, log_level: LogLevel) -> bool:
if ddp.get_local_rank() != 0:
if dist.get_local_rank() != 0:
return False
return self._will_log(state, log_level)

@@ -126,6 +126,6 @@ def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData

@final
def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None:
if ddp.get_local_rank() != 0:
if dist.get_local_rank() != 0:
return
return self._log_metric(epoch, step, log_level, data)
6 changes: 3 additions & 3 deletions composer/core/state.py
@@ -13,7 +13,7 @@
import composer.core.types as types
from composer.core.precision import Precision
from composer.core.serializable import Serializable
from composer.utils import ddp, ensure_tuple
from composer.utils import dist, ensure_tuple
from composer.utils.precision import default_precision_factory

if TYPE_CHECKING:
@@ -185,14 +185,14 @@ def train_batch_size(self):
"""The global batch size used for training."""
if self.train_dataloader.batch_size is None:
raise RuntimeError("train dataloader batch size is undefined")
return self.train_dataloader.batch_size * ddp.get_world_size()
return self.train_dataloader.batch_size * dist.get_world_size()

@property
def eval_batch_size(self):
"""The batch size used for evaluation."""
if self.eval_dataloader.batch_size is None:
raise RuntimeError("eval dataloader batch size is undefined")
return self.eval_dataloader.batch_size * ddp.get_world_size()
return self.eval_dataloader.batch_size * dist.get_world_size()

def state_dict(self) -> types.StateDict:
"""Returns the state as a :class:`dict`."""
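As the `train_batch_size` and `eval_batch_size` properties above show, the dataloader's `batch_size` is a per-rank value, so the global batch size is that value multiplied by the world size. A quick illustration with made-up numbers:

```python
per_rank_batch_size = 64  # train_dataloader.batch_size on each rank
world_size = 8            # dist.get_world_size()

# Global batch size: samples consumed across all ranks per step.
train_batch_size = per_rank_batch_size * world_size
print(train_batch_size)  # 512
```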
4 changes: 2 additions & 2 deletions composer/datasets/brats.py
@@ -14,7 +14,7 @@
from composer.core.types import DataLoader, Dataset
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DatasetHparams
from composer.utils import ddp
from composer.utils import dist

PATCH_SIZE = [1, 192, 160]

@@ -48,7 +48,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
x_train, y_train, x_val, y_val = get_data_split(self.datadir)
dataset = PytTrain(x_train, y_train, oversampling) if self.is_train else PytVal(x_val, y_val)
collate_fn = None if self.is_train else _my_collate
sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

return dataloader_hparams.initialize_object(
dataset=dataset,
4 changes: 2 additions & 2 deletions composer/datasets/cifar10.py
@@ -10,7 +10,7 @@
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.utils import ddp
from composer.utils import dist


@dataclass
@@ -59,7 +59,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
download=self.download,
transform=transformation,
)
sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

return dataloader_hparams.initialize_object(dataset,
batch_size=batch_size,
4 changes: 2 additions & 2 deletions composer/datasets/imagenet.py
@@ -16,7 +16,7 @@
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DataloaderSpec, DatasetHparams, SyntheticHparamsMixin
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.utils import ddp
from composer.utils import dist


class TransformationFn:
@@ -119,7 +119,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
if self.datadir is None:
raise ValueError("datadir must be specified if self.synthetic is False")
dataset = ImageFolder(os.path.join(self.datadir, split), transformation)
sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

return DataloaderSpec(dataloader=dataloader_hparams.initialize_object(
dataset=dataset,
4 changes: 2 additions & 2 deletions composer/datasets/lm_datasets.py
@@ -11,7 +11,7 @@
from composer.core.types import Batch
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DataloaderSpec, DatasetHparams
from composer.utils import ddp
from composer.utils import dist

log = logging.getLogger(__name__)

@@ -102,7 +102,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
dataset = lm_datasets
data_collator = transformers.default_data_collator

sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

return DataloaderSpec(dataloader=dataloader_hparams.initialize_object(
dataset=dataset,
4 changes: 2 additions & 2 deletions composer/datasets/mnist.py
@@ -9,7 +9,7 @@
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.utils import ddp
from composer.utils import dist


@dataclass
@@ -43,7 +43,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
download=self.download,
transform=transform,
)
sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
return dataloader_hparams.initialize_object(dataset=dataset,
batch_size=batch_size,
sampler=sampler,
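Each of the dataset hparams classes above now obtains its sampler from `dist.get_sampler(dataset, drop_last=..., shuffle=...)`. A helper like this typically wraps `torch.utils.data.distributed.DistributedSampler`; the sketch below shows one plausible shape for such a wrapper and is not the actual composer implementation:

```python
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler


def get_sampler(dataset: Dataset, *, drop_last: bool, shuffle: bool) -> DistributedSampler:
    """Sketch of a dist.get_sampler-style helper.

    Requires torch.distributed to be initialized; DistributedSampler then infers
    the rank and world size from the default process group.
    """
    return DistributedSampler(dataset, drop_last=drop_last, shuffle=shuffle)
```

The returned sampler is handed to the dataloader in place of `shuffle=True`, so each rank iterates over a disjoint shard of the dataset.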
16 changes: 8 additions & 8 deletions composer/loggers/wandb_logger.py
@@ -8,7 +8,7 @@

from composer.core.logging import BaseLoggerBackend, LogLevel, TLogData
from composer.core.types import Logger, State, StateDict
from composer.utils import ddp, run_directory
from composer.utils import dist, run_directory

import wandb # isort: skip

@@ -40,18 +40,18 @@ def __init__(self,

def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData):
del epoch, log_level # unused
if ddp.get_local_rank() == 0:
if dist.get_local_rank() == 0:
wandb.log(data, step=step)

def state_dict(self) -> StateDict:
# Storing these fields in the state dict to support run resuming in the future.
if ddp.get_local_rank() != 0:
if dist.get_local_rank() != 0:
raise RuntimeError("WandB can only be checkpointed on rank 0")
return {"name": wandb.run.name, "project": wandb.run.project, "entity": wandb.run.entity, "id": wandb.run.id}

def init(self, state: State, logger: Logger) -> None:
del state, logger # unused
if ddp.get_local_rank() == 0:
if dist.get_local_rank() == 0:
wandb.init(**self._init_params)

def batch_end(self, state: State, logger: Logger) -> None:
@@ -82,8 +82,8 @@ def _upload_artifacts(self):
return
# barrier that every process has reached this point and that
# previous callbacks are finished writing to the run directory
ddp.barrier()
if ddp.get_local_rank() == 0:
dist.barrier()
if dist.get_local_rank() == 0:
modified_files = run_directory.get_modified_files(self._last_upload_timestamp)
for modified_file in modified_files:
file_type = modified_file.split(".")[-1]
@@ -95,7 +95,7 @@ def post_close(self) -> None:
self._last_upload_timestamp = run_directory.get_run_directory_timestamp()
# barrier to ensure that other processes do not continue to other callbacks
# that could start writing to the run directory
ddp.barrier()
dist.barrier()

def post_close(self) -> None:
# Cleaning up on post_close so all artifacts are uploaded
@@ -104,7 +104,7 @@

exc_tpe, exc_info, tb = sys.exc_info()

if ddp.get_local_rank() == 0:
if dist.get_local_rank() == 0:
if (exc_tpe, exc_info, tb) == (None, None, None):
wandb.finish(0)
else: