11 changes: 7 additions & 4 deletions composer/algorithms/seq_length_warmup/seq_length_warmup.py
@@ -1,6 +1,5 @@
 # Copyright 2021 MosaicML. All Rights Reserved.
 
-import math
 from dataclasses import asdict, dataclass
 from typing import Dict, Mapping, Optional
 
@@ -10,7 +9,7 @@
 from composer.algorithms import AlgorithmHparams
 from composer.core.types import Algorithm, Batch, Event, Logger, State, Tensor
 from composer.models.transformer_shared import MosaicTransformer
-from composer.utils import dist, ensure_tuple
+from composer.utils import ensure_tuple
 
 
 def apply_seq_length_warmup(batch: Dict[str, Tensor], curr_seq_len: int, truncate: bool) -> Batch:
@@ -180,8 +179,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
             # all of the parameters
             device = next(state.model.parameters()).device
 
-            assert (state.train_batch_size % dist.get_world_size()) == 0
-            per_gpu_batch = math.ceil(state.train_batch_size / (dist.get_world_size() * state.grad_accum))
+            per_gpu_macrobatch = state.train_dataloader.batch_size
+            if per_gpu_macrobatch is None:
+                raise RuntimeError("seq_length_warmup requires constant batch sizing")
+            assert per_gpu_macrobatch % state.grad_accum == 0, "grad accum should evenly divide the batch"
+            per_gpu_batch = per_gpu_macrobatch // state.grad_accum
 
             input_ids = torch.randint(low=0,
                                       high=vocab_size - 1,
                                       size=(per_gpu_batch, self.hparams.max_seq_length),
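
Note: the sizing logic above no longer derives a per-GPU batch from the global train_batch_size (which needed dist.get_world_size() and math.ceil); it reads the per-GPU macrobatch straight off the dataloader and floor-divides by grad_accum. A minimal sketch of the new arithmetic, with illustrative names (per_gpu_microbatch_size is not part of the PR):

from typing import Optional

def per_gpu_microbatch_size(dataloader_batch_size: Optional[int], grad_accum: int) -> int:
    # dataloader_batch_size stands in for state.train_dataloader.batch_size,
    # which is per-GPU and must be constant for this algorithm to size its dummy batch.
    if dataloader_batch_size is None:
        raise RuntimeError("seq_length_warmup requires constant batch sizing")
    assert dataloader_batch_size % grad_accum == 0, "grad accum should evenly divide the batch"
    return dataloader_batch_size // grad_accum

# e.g. a per-GPU dataloader batch of 32 with grad_accum=4 yields 8-sample microbatches
assert per_gpu_microbatch_size(32, 4) == 8
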
7 changes: 6 additions & 1 deletion composer/callbacks/benchmarker.py
@@ -7,11 +7,14 @@
 import warnings
 from typing import Sequence
 
+import torch
+
 from composer.callbacks.callback_hparams import BenchmarkerHparams
 from composer.core import Logger, State
 from composer.core.callback import Callback
 from composer.core.types import BreakEpochException
 from composer.utils import dist
+from composer.utils.data import get_device_of_batch
 
 log = logging.getLogger(__name__)

@@ -158,7 +161,9 @@ def batch_end(self, state: State, logger: Logger):
         now = time.time()
         elapsed = now - self.current_time
         self.current_time = now
-        self.profile_examples += state.last_batch_size * dist.get_world_size()
+        batch_num_samples = torch.tensor([int(state.batch_num_samples)], device=get_device_of_batch(state.batch))
+        dist.all_reduce(batch_num_samples, reduce_operation="SUM")
+        self.profile_examples += int(batch_num_samples.item())
         self.profile_steps += 1
         self.profile_time += elapsed
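
Note: the old accounting assumed equal batches on every rank (local count × world size); the replacement all-reduces each rank's actual state.batch_num_samples, so uneven batches (e.g. a short final batch) are counted exactly. A sketch of the same pattern against raw torch.distributed, which composer's dist.all_reduce presumably wraps (the function name is illustrative):

import torch
import torch.distributed as torch_dist

def global_num_samples(local_num_samples: int, device: torch.device) -> int:
    # Sum each rank's sample count; every rank receives the global total.
    count = torch.tensor([local_num_samples], device=device)
    if torch_dist.is_available() and torch_dist.is_initialized():
        torch_dist.all_reduce(count, op=torch_dist.ReduceOp.SUM)
    return int(count.item())
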
14 changes: 6 additions & 8 deletions composer/callbacks/speed_monitor.py
@@ -6,11 +6,14 @@
 from collections import deque
 from typing import Deque, Optional
 
+import torch
+
 from composer.callbacks.callback_hparams import SpeedMonitorHparams
 from composer.core import Logger, State
 from composer.core.callback import RankZeroCallback
 from composer.core.types import StateDict
 from composer.utils import dist
+from composer.utils.data import get_device_of_batch
 
 
 class SpeedMonitor(RankZeroCallback):
@@ -78,14 +81,9 @@ def epoch_start(self, state: State, logger: Logger):
 
     def batch_end(self, state: State, logger: Logger):
         self.batch_end_times.append(time.time())
-        batch_num_samples = 0
-        batch_num_samples += state.last_batch_size
-        # TODO this is a hack around not having proper syncing / reduction available in callbacks
-        # Ideally, callbacks would have a way of reducing tensors.
-        # It assumes that each process has equal batch sizing
-        # For the speed monitor, we might be able to use the static step converter with num_samples
-        batch_num_samples *= dist.get_world_size()
-        self.batch_num_samples.append(batch_num_samples)
+        batch_num_samples = torch.tensor([int(state.batch_num_samples)], device=get_device_of_batch(state.batch))
+        dist.all_reduce(batch_num_samples, reduce_operation="SUM")
+        self.batch_num_samples.append(int(batch_num_samples.item()))
         self.train_examples_per_epoch += batch_num_samples
         if len(self.batch_end_times) == self.hparams.window_size + 1:
             throughput = sum(self.batch_num_samples) / (self.batch_end_times[-1] - self.batch_end_times[0])
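
Note: this is the same all-reduce swap as in benchmarker.py, so the throughput window now sums exact global sample counts. For reference, a sketch of the sliding-window computation the callback performs, assuming window_size + 1 timestamps bracket window_size sample counts (names and deque sizes are illustrative):

from collections import deque

def window_throughput(batch_end_times, batch_num_samples) -> float:
    # The first timestamp only marks the window start, so with n + 1 timestamps
    # the samples deque covers the n batches that ended inside the window.
    return sum(batch_num_samples) / (batch_end_times[-1] - batch_end_times[0])

times = deque([0.0, 1.0, 2.0])       # window_size = 2, so 3 timestamps
samples = deque([256, 256])          # global samples for the 2 windowed batches
assert window_throughput(times, samples) == 256.0
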
82 changes: 26 additions & 56 deletions composer/core/state.py
@@ -5,18 +5,17 @@
 import logging
 import textwrap
 import warnings
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Optional, Sequence, Union, cast
+from typing import TYPE_CHECKING, Callable, ContextManager, Optional, Sequence, Union, cast
 
 import torch
 import torch.nn.modules.utils
 from torch.nn.parallel import DistributedDataParallel
 
 import composer.core.types as types
-from composer.core.data_spec import DataSpec
 from composer.core.precision import Precision
 from composer.core.serializable import Serializable
 from composer.core.time import Time, Timer, TimeUnit
-from composer.utils import dist, ensure_tuple
+from composer.utils import ensure_tuple
 from composer.utils.precision import default_precision_factory
 
 if TYPE_CHECKING:
@@ -55,9 +54,13 @@
 SKIP_SERIALIZATION_FIELDS = [
     "loss",
     "batch",
+    "batch_num_samples",
+    "batch_num_tokens",
+    "microbatches",
+    "microbatch_idx",
     "outputs",
-    "train_data",
-    "eval_data",
+    "train_dataloader",
+    "eval_dataloader",
     "_steps_per_epoch",
     "_precision_context",
 ]
@@ -91,14 +94,26 @@ class State(Serializable):
         callbacks (Sequence[Callback]): The callbacks used for training.
 
     Attributes:
-        batch (types.Batch): The most recently retrieved batch.
-        last_batch_size (int): The size of the batch last returned from the dataloader. This can be different from the current size of ``batch`` if algorithms have modified the ``batch``.
+        batch (types.Batch): The batch. This will be the entire batch during the :attr:`Event.AFTER_DATALOADER`, or a
+            microbatch between :attr:`Event.BATCH_START` and :attr:`Event.BATCH_END`.
+        batch_num_samples (int): The number of samples in the :attr:`batch`.
+        batch_num_tokens (int): The number of tokens in the :attr:`batch`.
+        microbatches (Sequence[types.Batch]): The batch, split into ``grad_accum`` microbatches.
+        microbatch_idx (int): The current microbatch index, which will be in [0, ``grad_accum``).
+
         loss (types.Tensors): The most recently computed loss.
         outputs (types.Tensors): The most recently computed output from the model's forward pass.
         timer (types.Timer): The timer that tracks training loop progress.
     """
 
     _max_duration: Time[int]
+    batch: types.Batch
+    batch_num_samples: int
+    batch_num_tokens: int
+    microbatches: Sequence[types.Batch]
+    microbatch_idx: int
+    loss: types.Tensors
+    outputs: types.Tensors
 
     def __init__(
         self,
@@ -107,8 +122,8 @@ def __init__(
 
         # data configurations
         grad_accum: int,
-        train_dataloader: Union[types.DataLoader, types.DataSpec, Dict[str, Any]],
-        eval_dataloader: Union[types.DataLoader, types.DataSpec, Dict[str, Any]],
+        train_dataloader: types.DataLoader,
+        eval_dataloader: types.DataLoader,
 
         # stopping conditions
         max_duration: Union[str, Time[int]],
@@ -130,30 +145,15 @@
     ):
         self.model = model
         self.grad_accum = grad_accum
-        if not isinstance(train_dataloader, DataSpec):
-            if isinstance(train_dataloader, dict):
-                train_dataloader = DataSpec(**train_dataloader)
-            else:
-                train_dataloader = DataSpec(train_dataloader)
-        self.train_data = train_dataloader
-        if not isinstance(eval_dataloader, DataSpec):
-            if isinstance(eval_dataloader, dict):
-                eval_dataloader = DataSpec(**eval_dataloader)
-            else:
-                eval_dataloader = DataSpec(eval_dataloader)
-        self.eval_data = eval_dataloader
+        self.train_dataloader = train_dataloader
+        self.eval_dataloader = eval_dataloader
         self.max_duration = max_duration
 
         self.timer = Timer()
         self._precision = Precision(precision)
         self._steps_per_epoch = None
         self._precision_context = precision_context
 
-        self.batch: types.Batch = {}
-        self.last_batch_size = 0
-        self.loss: types.Tensors = torch.zeros(size=(1,))
-        self.outputs: types.Tensors = torch.zeros(size=(1,))
-
         if optimizers is None:
             self._optimizers = []
         else:
@@ -212,22 +212,6 @@ def max_epochs(self):
         assert self.max_duration.unit == TimeUnit.EPOCH, "invariant violation -- max duration must be epochs for now"
         return self.max_duration.value
 
-    @property
-    def train_dataloader(self):
-        return self.train_data.dataloader
-
-    @train_dataloader.setter
-    def train_dataloader(self, dataloader: types.DataLoader):
-        self.train_data = DataSpec(dataloader)
-
-    @property
-    def eval_dataloader(self):
-        return self.eval_data.dataloader
-
-    @eval_dataloader.setter
-    def eval_dataloader(self, dataloader: types.DataLoader):
-        self.eval_data = DataSpec(dataloader)
-
     @property
     def optimizers(self):
         return self._optimizers
@@ -260,20 +244,6 @@ def algorithms(self):
     def algorithms(self, algorithms: Sequence[Algorithm]):
         self._algorithms[:] = algorithms
 
-    @property
-    def train_batch_size(self):
-        """The global batch size used for training."""
-        if self.train_dataloader.batch_size is None:
-            raise RuntimeError("train dataloader batch size is undefined")
-        return self.train_dataloader.batch_size * dist.get_world_size()
-
-    @property
-    def eval_batch_size(self):
-        """The batch size used for evaluation."""
-        if self.eval_dataloader.batch_size is None:
-            raise RuntimeError("eval dataloader batch size is undefined")
-        return self.eval_dataloader.batch_size * dist.get_world_size()
-
     def state_dict(self) -> types.StateDict:
         """Returns the state as a :class:`dict`."""
         state_dict: types.StateDict = {}
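
Note: State now carries the whole batch plus its grad_accum microbatches and per-batch sample/token counts, replacing last_batch_size and the DataSpec wrapping of the dataloaders. An illustrative sketch (not composer's implementation) of how a dict-of-tensors batch could populate the new microbatches field:

from typing import Dict, List
import torch

def split_into_microbatches(batch: Dict[str, torch.Tensor], grad_accum: int) -> List[Dict[str, torch.Tensor]]:
    # Chunk every tensor along the sample dimension, then regroup per microbatch.
    chunked = {name: tensor.chunk(grad_accum) for name, tensor in batch.items()}
    return [{name: chunks[idx] for name, chunks in chunked.items()} for idx in range(grad_accum)]

batch = {"input_ids": torch.zeros(8, 128, dtype=torch.long)}
microbatches = split_into_microbatches(batch, grad_accum=4)
assert len(microbatches) == 4 and microbatches[0]["input_ids"].shape == (2, 128)
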
4 changes: 3 additions & 1 deletion composer/trainer/deepspeed.py
@@ -50,9 +50,11 @@ def validate(self):
             warnings.warn("ZeRO stage 3 is largely untested with composer. Certain algorithms may break.")
 
     def initialize_object(self, state: State, grad_clip_norm: Optional[float]):
+        if state.train_dataloader.batch_size is None:
+            raise RuntimeError("Deepspeed requires a dataloader with a known batch size")
 
         deepspeed_config: dict[str, Any] = {
-            "train_batch_size": state.train_batch_size,
+            "train_micro_batch_size_per_gpu": state.train_dataloader.batch_size // state.grad_accum,
             "gradient_accumulation_steps": state.grad_accum,
             "zero_optimization": {
                 "stage": self.zero_stage,
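
Note: the config now hands DeepSpeed the per-GPU microbatch size rather than the global train_batch_size, which is what allowed State.train_batch_size to be deleted above. DeepSpeed's documented invariant is train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size; a sketch of that arithmetic (the function name is illustrative):

def deepspeed_batch_sizes(per_gpu_batch_size: int, grad_accum: int, world_size: int) -> dict:
    micro = per_gpu_batch_size // grad_accum
    return {
        "train_micro_batch_size_per_gpu": micro,
        "gradient_accumulation_steps": grad_accum,
        # implied by DeepSpeed from the two values above, not set explicitly here:
        "train_batch_size": micro * grad_accum * world_size,
    }

# e.g. per-GPU dataloader batch 32, grad_accum 4, 8 ranks -> microbatch 8, global batch 256
assert deepspeed_batch_sizes(32, 4, 8)["train_batch_size"] == 256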