diff --git a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py
index b9c10ff4d1..b1c2b25abd 100644
--- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py
+++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py
@@ -1,6 +1,5 @@
 # Copyright 2021 MosaicML. All Rights Reserved.
 
-import math
 from dataclasses import asdict, dataclass
 from typing import Dict, Mapping, Optional
 
@@ -10,7 +9,7 @@
 from composer.algorithms import AlgorithmHparams
 from composer.core.types import Algorithm, Batch, Event, Logger, State, Tensor
 from composer.models.transformer_shared import MosaicTransformer
-from composer.utils import dist, ensure_tuple
+from composer.utils import ensure_tuple
 
 
 def apply_seq_length_warmup(batch: Dict[str, Tensor], curr_seq_len: int, truncate: bool) -> Batch:
@@ -180,8 +179,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
             # all of the parameters
             device = next(state.model.parameters()).device
 
-            assert (state.train_batch_size % dist.get_world_size()) == 0
-            per_gpu_batch = math.ceil(state.train_batch_size / (dist.get_world_size() * state.grad_accum))
+            per_gpu_macrobatch = state.train_dataloader.batch_size
+            if per_gpu_macrobatch is None:
+                raise RuntimeError("seq_length_warmup requires constant batch sizing")
+            assert per_gpu_macrobatch % state.grad_accum == 0, "grad accum should evenly divide the batch"
+            per_gpu_batch = per_gpu_macrobatch // state.grad_accum
+
             input_ids = torch.randint(low=0,
                                       high=vocab_size - 1,
                                       size=(per_gpu_batch, self.hparams.max_seq_length),
diff --git a/composer/callbacks/benchmarker.py b/composer/callbacks/benchmarker.py
index 3503116727..2228499139 100644
--- a/composer/callbacks/benchmarker.py
+++ b/composer/callbacks/benchmarker.py
@@ -7,11 +7,14 @@
 import warnings
 from typing import Sequence
 
+import torch
+
 from composer.callbacks.callback_hparams import BenchmarkerHparams
 from composer.core import Logger, State
 from composer.core.callback import Callback
 from composer.core.types import BreakEpochException
 from composer.utils import dist
+from composer.utils.data import get_device_of_batch
 
 log = logging.getLogger(__name__)
 
@@ -158,7 +161,9 @@ def batch_end(self, state: State, logger: Logger):
         now = time.time()
         elapsed = now - self.current_time
         self.current_time = now
-        self.profile_examples += state.last_batch_size * dist.get_world_size()
+        batch_num_samples = torch.tensor([int(state.batch_num_samples)], device=get_device_of_batch(state.batch))
+        dist.all_reduce(batch_num_samples, reduce_operation="SUM")
+        self.profile_examples += int(batch_num_samples.item())
         self.profile_steps += 1
         self.profile_time += elapsed
diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py
index dbb51b7ba9..9bfc6986ff 100644
--- a/composer/callbacks/speed_monitor.py
+++ b/composer/callbacks/speed_monitor.py
@@ -6,11 +6,14 @@
 from collections import deque
 from typing import Deque, Optional
 
+import torch
+
 from composer.callbacks.callback_hparams import SpeedMonitorHparams
 from composer.core import Logger, State
 from composer.core.callback import RankZeroCallback
 from composer.core.types import StateDict
 from composer.utils import dist
+from composer.utils.data import get_device_of_batch
 
 
 class SpeedMonitor(RankZeroCallback):
@@ -78,14 +81,9 @@ def epoch_start(self, state: State, logger: Logger):
 
     def batch_end(self, state: State, logger: Logger):
         self.batch_end_times.append(time.time())
-        batch_num_samples = 0
-        batch_num_samples += state.last_batch_size
-        # TODO this is a hack around not having proper syncing / reduction available in callbacks
-        # Ideally, callbacks would have a way of reducing tensors.
-        # It assumes that each process has equal batch sizing
-        # For the speed monitor, we might be able to use the static step converter with num_samples
-        batch_num_samples *= dist.get_world_size()
-        self.batch_num_samples.append(batch_num_samples)
+        batch_num_samples = torch.tensor([int(state.batch_num_samples)], device=get_device_of_batch(state.batch))
+        dist.all_reduce(batch_num_samples, reduce_operation="SUM")
+        self.batch_num_samples.append(int(batch_num_samples.item()))
         self.train_examples_per_epoch += batch_num_samples
         if len(self.batch_end_times) == self.hparams.window_size + 1:
             throughput = sum(self.batch_num_samples) / (self.batch_end_times[-1] - self.batch_end_times[0])
diff --git a/composer/core/state.py b/composer/core/state.py
index 20317b5a5d..a98227d831 100755
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -5,18 +5,17 @@
 import logging
 import textwrap
 import warnings
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Optional, Sequence, Union, cast
+from typing import TYPE_CHECKING, Callable, ContextManager, Optional, Sequence, Union, cast
 
 import torch
 import torch.nn.modules.utils
 from torch.nn.parallel import DistributedDataParallel
 
 import composer.core.types as types
-from composer.core.data_spec import DataSpec
 from composer.core.precision import Precision
 from composer.core.serializable import Serializable
 from composer.core.time import Time, Timer, TimeUnit
-from composer.utils import dist, ensure_tuple
+from composer.utils import ensure_tuple
 from composer.utils.precision import default_precision_factory
 
 if TYPE_CHECKING:
@@ -55,9 +54,13 @@
 SKIP_SERIALIZATION_FIELDS = [
     "loss",
     "batch",
+    "batch_num_samples",
+    "batch_num_tokens",
+    "microbatches",
+    "microbatch_idx",
     "outputs",
-    "train_data",
-    "eval_data",
+    "train_dataloader",
+    "eval_dataloader",
     "_steps_per_epoch",
     "_precision_context",
 ]
@@ -91,14 +94,26 @@ class State(Serializable):
         callbacks (Sequence[Callback]): The callbacks used for training.
 
     Attributes:
-        batch (types.Batch): The most recently retrieved batch.
-        last_batch_size (int): The size of the batch last returned from the dataloader. This can be different from the current size of ``batch`` if algorithms have modified the ``batch``.
+        batch (types.Batch): The batch. This will be the entire batch during the :attr:`Event.AFTER_DATALOADER`, or a
+            microbatch between :attr:`Event.BATCH_START` and :attr:`Event.BATCH_END`.
+        batch_num_samples (int): The number of samples in the :attr:`batch`.
+        batch_num_tokens (int): The number of tokens in the :attr:`batch`.
+        microbatches (Sequence[type.Batch]): The batch, split into ``grad_accum`` microbatches.
+        microbatch_idx (int): The current microbatch index, which will be on [0, ``grad_accum``).
+
        loss (types.Tensors): The most recently computed loss.
         outputs (types.Tensors): The most recently computed output from the model's forward pass.
         timer (types.Timer): The timer that tracks training loop progress.
     """
 
     _max_duration: Time[int]
+    batch: types.Batch
+    batch_num_samples: int
+    batch_num_tokens: int
+    microbatches: Sequence[types.Batch]
+    microbatch_idx: int
+    loss: types.Tensors
+    outputs: types.Tensors
 
     def __init__(
         self,
@@ -107,8 +122,8 @@ def __init__(
         self,
         # data configurations
         grad_accum: int,
-        train_dataloader: Union[types.DataLoader, types.DataSpec, Dict[str, Any]],
-        eval_dataloader: Union[types.DataLoader, types.DataSpec, Dict[str, Any]],
+        train_dataloader: types.DataLoader,
+        eval_dataloader: types.DataLoader,
 
         # stopping conditions
         max_duration: Union[str, Time[int]],
@@ -130,18 +145,8 @@
     ):
         self.model = model
         self.grad_accum = grad_accum
-        if not isinstance(train_dataloader, DataSpec):
-            if isinstance(train_dataloader, dict):
-                train_dataloader = DataSpec(**train_dataloader)
-            else:
-                train_dataloader = DataSpec(train_dataloader)
-        self.train_data = train_dataloader
-        if not isinstance(eval_dataloader, DataSpec):
-            if isinstance(eval_dataloader, dict):
-                eval_dataloader = DataSpec(**eval_dataloader)
-            else:
-                eval_dataloader = DataSpec(eval_dataloader)
-        self.eval_data = eval_dataloader
+        self.train_dataloader = train_dataloader
+        self.eval_dataloader = eval_dataloader
         self.max_duration = max_duration
 
         self.timer = Timer()
@@ -149,11 +154,6 @@ def __init__(
         self._steps_per_epoch = None
         self._precision_context = precision_context
 
-        self.batch: types.Batch = {}
-        self.last_batch_size = 0
-        self.loss: types.Tensors = torch.zeros(size=(1,))
-        self.outputs: types.Tensors = torch.zeros(size=(1,))
-
        if optimizers is None:
             self._optimizers = []
         else:
@@ -212,22 +212,6 @@ def max_epochs(self):
         assert self.max_duration.unit == TimeUnit.EPOCH, "invariant violation -- max duration must be epochs for now"
         return self.max_duration.value
 
-    @property
-    def train_dataloader(self):
-        return self.train_data.dataloader
-
-    @train_dataloader.setter
-    def train_dataloader(self, dataloader: types.DataLoader):
-        self.train_data = DataSpec(dataloader)
-
-    @property
-    def eval_dataloader(self):
-        return self.eval_data.dataloader
-
-    @eval_dataloader.setter
-    def eval_dataloader(self, dataloader: types.DataLoader):
-        self.eval_data = DataSpec(dataloader)
-
     @property
     def optimizers(self):
         return self._optimizers
@@ -260,20 +244,6 @@ def algorithms(self):
     def algorithms(self, algorithms: Sequence[Algorithm]):
         self._algorithms[:] = algorithms
 
-    @property
-    def train_batch_size(self):
-        """The global batch size used for training."""
-        if self.train_dataloader.batch_size is None:
-            raise RuntimeError("train dataloader batch size is undefined")
-        return self.train_dataloader.batch_size * dist.get_world_size()
-
-    @property
-    def eval_batch_size(self):
-        """The batch size used for evaluation."""
-        if self.eval_dataloader.batch_size is None:
-            raise RuntimeError("eval dataloader batch size is undefined")
-        return self.eval_dataloader.batch_size * dist.get_world_size()
-
     def state_dict(self) -> types.StateDict:
         """Returns the state as a :class:`dict`."""
         state_dict: types.StateDict = {}
diff --git a/composer/trainer/deepspeed.py b/composer/trainer/deepspeed.py
index 010e53304f..484902df32 100755
--- a/composer/trainer/deepspeed.py
+++ b/composer/trainer/deepspeed.py
@@ -50,9 +50,11 @@ def validate(self):
             warnings.warn("ZeRO stage 3 is largely untested with composer. Certain algorithms may break.")
 
     def initialize_object(self, state: State, grad_clip_norm: Optional[float]):
+        if state.train_dataloader.batch_size is None:
+            raise RuntimeError("Deepspeed requires a dataloader with a known batch size")
         deepspeed_config: dict[str, Any] = {
-            "train_batch_size": state.train_batch_size,
+            "train_micro_batch_size_per_gpu": state.train_dataloader.batch_size // state.grad_accum,
             "gradient_accumulation_steps": state.grad_accum,
             "zero_optimization": {
                 "stage": self.zero_stage,
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 101bee337a..ef2ded8d1f 100755
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -8,7 +8,7 @@
 import logging
 import textwrap
 import warnings
-from typing import Any, Callable, ContextManager, Dict, List, Optional, Sequence, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Sequence, Union, cast
 
 import torch
 import torch.distributed
@@ -21,7 +21,7 @@
 from composer.core import Callback, DataSpec, Engine, Event, Logger, State, Time
 from composer.core.algorithm import Algorithm
 from composer.core.logging import BaseLoggerBackend, LogLevel
-from composer.core.types import Batch, BreakEpochException, DataLoader, Metrics, Precision, Tensor
+from composer.core.types import BreakEpochException, DataLoader, Metrics, Precision
 from composer.datasets.dataloader import DDPDataLoader
 from composer.loggers.tqdm_logger import TQDMLoggerBackend
 from composer.models.base import BaseMosaicModel
@@ -38,6 +38,9 @@
 from composer.trainer.trainer_hparams import TrainerHparams
 from composer.utils import dist, ensure_tuple, map_collection, reproducibility
 
+if TYPE_CHECKING:
+    import deepspeed
+
 log = logging.getLogger(__name__)
 
 
@@ -214,6 +217,9 @@ def __init__(
             eval_dataloader = DataSpec(eval_dataloader)
         eval_dataloader.dataloader = DDPDataLoader(eval_dataloader.dataloader)
 
+        self._train_data_spec = train_dataloader
+        self._eval_data_spec = eval_dataloader
+
         self.state = State(
             max_duration=max_duration,
             algorithms=algorithms,
@@ -222,8 +228,8 @@ def __init__(
             grad_accum=grad_accum,
             precision=precision,
             precision_context=precision_context,
-            train_dataloader=train_dataloader,
-            eval_dataloader=eval_dataloader,
+            train_dataloader=train_dataloader.dataloader,
+            eval_dataloader=eval_dataloader.dataloader,
         )
 
         # Steps per epoch
@@ -272,17 +278,18 @@ def __init__(
             if not isinstance(schedulers_hparams, list):
                 schedulers_hparams = [schedulers_hparams]
             optimizer = optimizer_hparams.initialize_object(param_group=self.state.model.parameters())
-            if self.state.train_data.num_samples is None:
+            if self._train_data_spec.num_samples is None or self.state.train_dataloader.batch_size is None:
                 samples_per_epoch = None
             else:
-                samples_per_epoch = min(self.state.steps_per_epoch * self.state.train_batch_size,
-                                        self.state.train_data.num_samples)
+                batch_size = self.state.train_dataloader.batch_size * dist.get_world_size()
+
+                samples_per_epoch = min(self.state.steps_per_epoch * batch_size, self._train_data_spec.num_samples)
             schedulers = [
                 x.initialize_object(optimizer=optimizer,
                                     max_training_duration=self.state.max_duration,
                                     steps_per_epoch=self.state.steps_per_epoch,
                                     samples_per_epoch=samples_per_epoch,
-                                    dataset_num_tokens=self.state.train_data.num_tokens)
+                                    dataset_num_tokens=self._train_data_spec.num_tokens)
                 for x in ensure_warmup_last(schedulers_hparams)
             ]
             self.state.optimizers = optimizer
@@ -557,9 +564,11 @@ def _train_loop(self) -> None:
                        self.checkpoint_loader.restore_checkpoint_rng_state(self.device)
                    continue
 
-                state.last_batch_size = state.train_data.get_num_samples_in_batch(state.batch)
                 state.batch = self.device.batch_to_device(state.batch)
-                state.batch = state.train_data.device_transforms(state.batch)
+                state.batch = self._train_data_spec.device_transforms(state.batch)
+                state.batch_num_samples = self._train_data_spec.get_num_samples_in_batch(state.batch)
+                state.batch_num_tokens = self._train_data_spec.get_num_tokens_in_batch(state.batch)
+                state.microbatches = self._train_data_spec.split_batch(state.batch, state.grad_accum)
 
                 if self.deepspeed_enabled:
                     state.batch = fix_batch_precision_for_deepspeed(state.batch, state.precision)
@@ -569,8 +578,7 @@ def _train_loop(self) -> None:
                     assert train_metrics is not None
                     state.model.eval()
                     with torch.no_grad():
-                        eval_microbatches = state.train_data.split_batch(state.batch, state.grad_accum)
-                        for eval_microbatch in eval_microbatches:
+                        for eval_microbatch in state.microbatches:
                            # TODO: Detect if self.run_event(Event.AFTER_DATALOADER) changes the training
                            # data and if so print a warning that metrics may return unexpected results
                            outputs, targets = self.original_model.validate(eval_microbatch)
@@ -580,7 +588,12 @@ def _train_loop(self) -> None:
 
                 self.engine.run_event(Event.AFTER_DATALOADER)
 
-                microbatches = state.train_data.split_batch(state.batch, state.grad_accum)
+                num_samples_in_batch = self.device.tensor_to_device(
+                    torch.tensor([state.batch_num_samples], dtype=torch.int))
+                num_tokens_in_batch = self.device.tensor_to_device(
+                    torch.tensor([state.batch_num_tokens], dtype=torch.int))
+                dist.all_reduce(num_samples_in_batch, reduce_operation="SUM")
+                dist.all_reduce(num_tokens_in_batch, reduce_operation="SUM")
 
                 self.engine.run_event(Event.BATCH_START)
                 self.logger.metric_batch({
@@ -589,18 +602,15 @@ def _train_loop(self) -> None:
                 })
                 total_loss = None
                 if self.deepspeed_enabled:
-                    total_loss = self._train_batch(microbatches)
+                    total_loss = self._train_batch()
                 elif self._use_closures():
-                    closure = lambda **kwargs: self._train_batch(microbatches, **kwargs)
                     for optimizer in state.optimizers:
                         if use_grad_scaling:
-                            total_loss = state.scaler.step(optimizer, closure=closure)
+                            total_loss = state.scaler.step(optimizer, closure=self._train_batch)
                         else:
-                            # Torch optimizers technically expect closures to return a float, not a Tensor.
-                            # In practice, this doesn't seem to actually matter.
-                            total_loss = optimizer.step(closure=closure)  # type: ignore
+                            total_loss = optimizer.step(closure=lambda: self._train_batch().item())
                 else:
-                    total_loss = self._train_batch(microbatches)
+                    total_loss = self._train_batch()
                     for optimizer in state.optimizers:
                         if use_grad_scaling:
                             state.scaler.step(optimizer)
@@ -611,7 +621,8 @@ def _train_loop(self) -> None:
                    state.scaler.update()
 
                 if total_loss is not None:
-                    assert isinstance(total_loss, Tensor)
+                    if not isinstance(total_loss, torch.Tensor):
+                        total_loss = self.device.tensor_to_device(torch.tensor([total_loss]))
                    # total_loss can be None if gradient scaling failed
                    dist.all_reduce(total_loss, reduce_operation="SUM")
@@ -632,8 +643,8 @@ def _train_loop(self) -> None:
                    self.eval(is_batch=True)
 
                 state.timer.on_batch_complete(
-                    samples=state.train_data.get_num_samples_in_batch(state.batch),
-                    tokens=state.train_data.get_num_tokens_in_batch(state.batch),
+                    samples=int(num_samples_in_batch.item()),
+                    tokens=int(num_tokens_in_batch.item()),
                 )
 
                 if self.checkpoint_saver and self.checkpoint_saver.should_checkpoint(state=state, event=Event.BATCH_END):
@@ -662,7 +673,7 @@ def _train_loop(self) -> None:
 
         self.engine.run_event(Event.TRAINING_END)
 
-    def _train_batch(self, microbatches: Sequence[Batch], ddp_sync: bool = True):
+    def _train_batch(self, ddp_sync: bool = True):
         """Run training on a full batch of data.
 
         Args:
@@ -677,14 +688,11 @@ def _train_batch(self, microbatches: Sequence[Batch], ddp_sync: bool = True):
            context = self.state.model.no_sync
 
         with context():  # type: ignore - Pyright apparently doesn't recognize no_sync
-            return self._train_batch_inner(microbatches)
+            return self._train_batch_inner()
 
-    def _train_batch_inner(self, microbatches: Sequence[Batch]):
+    def _train_batch_inner(self):
         """Iterate over microbatches and compute the loss that will be used to step the optimizer.
-
-        Args:
-            microbatches (Sequence[Batch]): The microbatches which make up the batch.
         """
 
         self.engine.run_event(Event.BEFORE_TRAIN_BATCH)
@@ -700,15 +708,16 @@ def _train_batch_inner(self):
 
         # tracker for gradient accumulation
         total_loss = self.device.tensor_to_device(torch.zeros(size=(1,)))
-        current_batch_size = sum([state.train_data.get_num_samples_in_batch(batch) for batch in microbatches])
+        current_batch_size = sum(
+            [self._train_data_spec.get_num_samples_in_batch(batch) for batch in state.microbatches])
 
-        for microbatch_idx, state.batch in enumerate(microbatches):
-            is_final_microbatch = microbatch_idx + 1 == len(microbatches)
+        for state.microbatch_idx, state.batch in enumerate(state.microbatches):
+            state.batch_num_tokens = self._train_data_spec.get_num_tokens_in_batch(state.batch)
+            state.batch_num_samples = self._train_data_spec.get_num_samples_in_batch(state.batch)
+            is_final_microbatch = state.microbatch_idx + 1 == len(state.microbatches)
             sync_context = contextlib.nullcontext() if self.deepspeed_enabled else ddp_sync_context(
                 state, is_final_microbatch, self.ddp_sync_strategy)
             with sync_context:
-                last_microbatch_size = state.train_data.get_num_samples_in_batch(state.batch)
-
                 # forward pass
                 self.engine.run_event(Event.BEFORE_FORWARD)
@@ -732,7 +741,7 @@ def _train_batch_inner(self):
                 # Likely need to look into the performance impact
                 if not self.deepspeed_enabled:
                     for loss in ensure_tuple(state.loss):
-                        loss.mul_(last_microbatch_size / current_batch_size)
+                        loss.mul_(state.batch_num_samples / current_batch_size)
                         total_loss += loss.detach().clone()
 
                 assert state.loss is not None
@@ -745,11 +754,11 @@ def _train_batch_inner(self):
                    state.loss = state.scaler.scale(state.loss)
 
                 if self.deepspeed_enabled:
-                    state.model.backward(state.loss)  # type: ignore
+                    cast("deepspeed.DeepSpeedEngine", state.model).backward(state.loss)
 
                    # This is the same loss scaling and reporting we skipped earlier.
                    for loss in ensure_tuple(state.loss):
-                        loss.mul_(last_microbatch_size / current_batch_size)
+                        loss.mul_(state.batch_num_samples / current_batch_size)
                        total_loss += loss.detach().clone()
                 else:
                     for loss in ensure_tuple(state.loss):
@@ -758,7 +767,7 @@ def _train_batch_inner(self):
            self.engine.run_event(Event.AFTER_BACKWARD)
 
         if self.deepspeed_enabled:
-            state.model.step()  # type: ignore
+            cast("deepspeed.DeepSpeedEngine", state.model).step()
 
         # Unscale gradients before `Event.AFTER_TRAIN_BATCH`
         if use_grad_scaling:
@@ -798,7 +807,9 @@ def eval(self, is_batch: bool):
 
            for state.batch in itertools.islice(state.eval_dataloader, self._eval_subset_num_batches):
                state.batch = self.device.batch_to_device(state.batch)
-                state.batch = state.eval_data.device_transforms(state.batch)
+                state.batch = self._eval_data_spec.device_transforms(state.batch)
+                state.batch_num_samples = self._eval_data_spec.get_num_samples_in_batch(state.batch)
+                state.batch_num_tokens = self._eval_data_spec.get_num_tokens_in_batch(state.batch)
 
                if self.deepspeed_enabled:
                    state.batch = fix_batch_precision_for_deepspeed(state.batch, state.precision)
diff --git a/composer/utils/data.py b/composer/utils/data.py
index 70e1a48710..981f859068 100644
--- a/composer/utils/data.py
+++ b/composer/utils/data.py
@@ -10,6 +10,7 @@
 from torchvision import transforms
 
 from composer.core.types import Batch, Dataset, Tensor
+from composer.utils.iter_helpers import ensure_tuple
 
 
 class NormalizationFn:
@@ -135,3 +136,26 @@ def get_subset_dataset(size: int, dataset: Dataset):
         raise ValueError(f"The dataset length ({len(dataset)}) is less than the requested size ({size}).")
     dataset = torch.utils.data.Subset(dataset, list(range(size)))
     return dataset
+
+
+def get_device_of_batch(batch: Batch) -> torch.device:
+    """Returns the :class:`torch.device` of the batch.
+
+    Args:
+        batch (Batch): The batch to determine the device of.
+
+    Returns:
+        torch.device: The device that the batch is on.
+    """
+    if isinstance(batch, Tensor):
+        return batch.device
+    if isinstance(batch, (tuple, list)):  # BatchPair
+        for sample in ensure_tuple(batch):
+            for x in ensure_tuple(sample):
+                for tensor in ensure_tuple(x):
+                    return tensor.device
+
+    if isinstance(batch, dict):  # BatchDict
+        for x in batch.values():
+            return x.device
+    raise TypeError(f"Unsupported type for batch: {type(batch)}")
diff --git a/composer/utils/iter_helpers.py b/composer/utils/iter_helpers.py
index c8ec8e2c75..1e62d0510b 100644
--- a/composer/utils/iter_helpers.py
+++ b/composer/utils/iter_helpers.py
@@ -29,6 +29,8 @@ def map_collection(collection, map_fn):
 
 
 def ensure_tuple(x):
+    if x is None:
+        return tuple()
     if isinstance(x, tuple):
         return x
     if isinstance(x, list):
diff --git a/composer/utils/iter_helpers.pyi b/composer/utils/iter_helpers.pyi
index c81a7a057e..b0c0b99eae 100644
--- a/composer/utils/iter_helpers.pyi
+++ b/composer/utils/iter_helpers.pyi
@@ -28,7 +28,24 @@ def map_collection(dict_of_elements: Dict[KT, T], map_fn: Callable[[T], V], /) -> Dict[KT, V]:
 def map_collection(singleton: T, map_fn: Callable[[T], V], /) -> V:
     ...
 
-def ensure_tuple(x: Union[T, Tuple[T, ...], List[T], Dict[Any, T]]) -> Tuple[T, ...]:
+@overload
+def ensure_tuple(x: None) -> Tuple:
+    ...
+
+@overload
+def ensure_tuple(x: Tuple[T, ...]) -> Tuple[T, ...]:
+    ...
+
+@overload
+def ensure_tuple(x: List[T]) -> Tuple[T, ...]:
+    ...
+
+@overload
+def ensure_tuple(x: Dict[Any, T]) -> Tuple[T, ...]:
+    ...
+
+@overload
+def ensure_tuple(x: Union[T, None, Tuple[T, ...], List[T], Dict[Any, T]]) -> Tuple[T, ...]:
    ...
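
The two callback hunks above (Benchmarker and SpeedMonitor) follow the same pattern: read the per-rank `state.batch_num_samples`, build a tensor on the batch's device with the new `get_device_of_batch` helper, and sum it across ranks with `dist.all_reduce` before recording. A minimal sketch of that pattern in a standalone callback is below; the `SampleCounter` class and the `samples/total` metric name are hypothetical, and the sketch assumes only the post-change `State` fields and `composer.utils.dist` API shown in this diff.

```python
import torch

from composer.core import Logger, State
from composer.core.callback import Callback
from composer.utils import dist
from composer.utils.data import get_device_of_batch


class SampleCounter(Callback):
    """Hypothetical callback that counts samples seen across all ranks."""

    def __init__(self):
        self.total_samples = 0

    def batch_end(self, state: State, logger: Logger):
        # state.batch_num_samples is per-rank; sum it over the world before recording.
        num_samples = torch.tensor([int(state.batch_num_samples)], device=get_device_of_batch(state.batch))
        dist.all_reduce(num_samples, reduce_operation="SUM")
        self.total_samples += int(num_samples.item())
        logger.metric_batch({"samples/total": self.total_samples})
```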
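The DeepSpeed hunk swaps the global `train_batch_size` key for `train_micro_batch_size_per_gpu`, computed as the per-GPU dataloader batch size divided by `grad_accum`; DeepSpeed then derives the same effective global batch from the microbatch size, the accumulation steps, and the world size. A worked example with illustrative numbers (not taken from any config in this repository):

```python
# Illustrative numbers only; nothing here is read from the repository's configs.
per_gpu_dataloader_batch_size = 64   # state.train_dataloader.batch_size
grad_accum = 8                       # state.grad_accum
world_size = 4                       # dist.get_world_size()

# New config key: the per-GPU microbatch handed to DeepSpeed directly.
train_micro_batch_size_per_gpu = per_gpu_dataloader_batch_size // grad_accum  # 8

# DeepSpeed derives the same global batch the old "train_batch_size" key spelled out.
derived_train_batch_size = train_micro_batch_size_per_gpu * grad_accum * world_size  # 256
old_train_batch_size = per_gpu_dataloader_batch_size * world_size                    # 256
assert derived_train_batch_size == old_train_batch_size
```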
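The `ensure_tuple` change (None now maps to an empty tuple) is what lets `get_device_of_batch` walk nested batch structures without special-casing missing elements. A small usage sketch, assuming the existing dict handling returns the dict's values, as the `.pyi` overload `Dict[Any, T] -> Tuple[T, ...]` indicates:

```python
import torch

from composer.utils.data import get_device_of_batch
from composer.utils.iter_helpers import ensure_tuple

# New behavior from this diff: None becomes an empty tuple.
assert ensure_tuple(None) == ()
assert ensure_tuple([1, 2]) == (1, 2)
assert ensure_tuple({"a": 1}) == (1,)  # assumed: dict values, per the overload above

# get_device_of_batch handles a bare tensor, an (input, target) pair, and a dict batch.
x = torch.zeros(4, 3)
y = torch.zeros(4)
assert get_device_of_batch(x) == torch.device("cpu")
assert get_device_of_batch((x, y)) == torch.device("cpu")
assert get_device_of_batch({"input_ids": x}) == torch.device("cpu")
```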