Merged
6 changes: 0 additions & 6 deletions composer/core/state.py
@@ -57,8 +57,6 @@
"batch",
"batch_num_samples",
"batch_num_tokens",
- "microbatches",
- "microbatch_idx",
"outputs",
"train_dataloader",
"eval_dataloader",
@@ -102,8 +100,6 @@ class State(Serializable):
microbatch between :attr:`Event.BATCH_START` and :attr:`Event.BATCH_END`.
batch_num_samples (int): The number of samples in the :attr:`batch`.
batch_num_tokens (int): The number of tokens in the :attr:`batch`.
- microbatches (Sequence[type.Batch]): The batch, split into ``grad_accum`` microbatches.
- microbatch_idx (int): The current microbatch index, which will be on [0, ``grad_accum``).

loss (types.Tensors): The most recently computed loss.
outputs (types.Tensors): The most recently computed output from the model's forward pass.
@@ -114,8 +110,6 @@ class State(Serializable):
batch: types.Batch
batch_num_samples: int
batch_num_tokens: int
- microbatches: Sequence[types.Batch]
- microbatch_idx: int
loss: types.Tensors
outputs: types.Tensors

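Reviewer note (not part of the diff): since `microbatches` and `microbatch_idx` are no longer attributes of `State`, they also drop out of (de)serialization, and any code that previously read them has to recompute the split from `state.batch` and `state.grad_accum`. The toy example below sketches that call-site pattern; `ToyState` and `toy_split_batch` are illustrative stand-ins, not Composer APIs.

```python
from dataclasses import dataclass
from typing import Any, List

import torch


@dataclass
class ToyState:
    """Illustrative stand-in for composer.core.State; the microbatch fields are gone."""
    batch: Any
    grad_accum: int


def toy_split_batch(batch: torch.Tensor, grad_accum: int) -> List[torch.Tensor]:
    # Stand-in for DataSpec.split_batch; the real method also handles tuple/dict batches.
    return list(batch.chunk(grad_accum))


state = ToyState(batch=torch.randn(8, 3), grad_accum=4)

# Before this PR: microbatches = state.microbatches
# After this PR: recompute the split where it is needed.
microbatches = toy_split_batch(state.batch, state.grad_accum)
assert len(microbatches) == 4 and microbatches[0].shape[0] == 2
```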
32 changes: 16 additions & 16 deletions composer/trainer/trainer.py
@@ -21,7 +21,7 @@
from composer.core import Callback, DataSpec, Engine, Event, Logger, State, Time
from composer.core.algorithm import Algorithm
from composer.core.logging import BaseLoggerBackend, LogLevel
- from composer.core.types import BreakEpochException, DataLoader, Metrics, Precision
+ from composer.core.types import Batch, BreakEpochException, DataLoader, Metrics, Precision
from composer.datasets.dataloader import DDPDataLoader
from composer.loggers.tqdm_logger import TQDMLoggerBackend
from composer.models.base import BaseMosaicModel
@@ -583,7 +583,6 @@ def _train_loop(self) -> None:
state.batch = self._train_data_spec.device_transforms(state.batch)
state.batch_num_samples = self._train_data_spec.get_num_samples_in_batch(state.batch)
state.batch_num_tokens = self._train_data_spec.get_num_tokens_in_batch(state.batch)
- state.microbatches = self._train_data_spec.split_batch(state.batch, state.grad_accum)

if self.deepspeed_enabled:
state.batch = fix_batch_precision_for_deepspeed(state.batch, state.precision)
@@ -593,7 +592,7 @@ def _train_loop(self) -> None:
assert train_metrics is not None
state.model.eval()
with torch.no_grad():
- for eval_microbatch in state.microbatches:
+ for eval_microbatch in self._train_data_spec.split_batch(state.batch, state.grad_accum):
# TODO: Detect if self.run_event(Event.AFTER_DATALOADER) changes the training
# data and if so print a warning that metrics may return unexpected results
outputs, targets = self.original_model.validate(eval_microbatch)
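Reviewer aid: a minimal, self-contained sketch of the eval-metrics pattern above, where the split is computed at the point of use instead of being read from `state.microbatches`. Here `split_tuple_batch`, `ToyModel`, and the manual MSE bookkeeping are illustrative stand-ins for `DataSpec.split_batch`, `self.original_model.validate`, and `train_metrics`.

```python
import torch
import torch.nn as nn


def split_tuple_batch(batch, grad_accum):
    # Stand-in for self._train_data_spec.split_batch on (inputs, targets) tuples.
    inputs, targets = batch
    return list(zip(inputs.chunk(grad_accum), targets.chunk(grad_accum)))


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def validate(self, microbatch):
        # Mirrors the (outputs, targets) contract used by original_model.validate above.
        inputs, targets = microbatch
        return self.fc(inputs), targets


model = ToyModel()
batch = (torch.randn(8, 3), torch.randn(8, 1))

model.eval()
sq_err, n = 0.0, 0
with torch.no_grad():
    for eval_microbatch in split_tuple_batch(batch, grad_accum=4):
        outputs, targets = model.validate(eval_microbatch)
        sq_err += torch.sum((outputs - targets) ** 2).item()
        n += targets.numel()
print(f"mean squared error over the batch: {sq_err / n:.4f}")
```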
@@ -616,16 +615,18 @@ def _train_loop(self) -> None:
"trainer/batch_idx": self.state.timer.batch_in_epoch.value,
})
total_loss = None
+ microbatches = self._train_data_spec.split_batch(state.batch, state.grad_accum)
if self.deepspeed_enabled:
- total_loss = self._train_batch()
+ total_loss = self._train_batch(microbatches)
elif self._use_closures():
for optimizer in state.optimizers:
if use_grad_scaling:
- total_loss = state.scaler.step(optimizer, closure=self._train_batch)
+ total_loss = state.scaler.step(optimizer,
+ closure=lambda: self._train_batch(microbatches))
else:
- total_loss = optimizer.step(closure=lambda: self._train_batch().item())
+ total_loss = optimizer.step(closure=lambda: self._train_batch(microbatches).item())
else:
- total_loss = self._train_batch()
+ total_loss = self._train_batch(microbatches)
for optimizer in state.optimizers:
if use_grad_scaling:
state.scaler.step(optimizer)
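The key change in this hunk is that the optimizer closures now capture the locally computed `microbatches` rather than pulling them off `state`. Below is a hedged sketch of the non-scaler closure path only (the scaler path relies on Composer's closure-capable scaler, which is not reproduced here); `train_batch` is an illustrative stand-in for `self._train_batch`.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
microbatches = [torch.randn(2, 3) for _ in range(4)]


def train_batch(mbs):
    # Toy stand-in for self._train_batch: forward/backward over every microbatch and
    # return the accumulated loss as a 1-element tensor.
    optimizer.zero_grad()
    total = torch.zeros(1)
    for mb in mbs:
        loss = model(mb).mean()
        loss.backward()  # gradients accumulate across microbatches
        total += loss.detach()
    return total


# The lambda captures the local `microbatches` list, mirroring the diff above;
# torch optimizers call the closure inside step() and return its value.
total_loss = optimizer.step(closure=lambda: train_batch(microbatches).item())
print(total_loss)
```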
@@ -690,7 +691,7 @@ def _train_loop(self) -> None:

self.engine.run_event(Event.TRAINING_END)

- def _train_batch(self, ddp_sync: bool = True):
+ def _train_batch(self, microbatches: Sequence[Batch], ddp_sync: bool = True):
"""Run training on a full batch of data.

Args:
@@ -702,12 +703,12 @@ def _train_batch(self, ddp_sync: bool = True):
if ddp_sync or not isinstance(self.state.model, DistributedDataParallel):
context = contextlib.nullcontext
else:
- context = self.state.model.no_sync
+ context = cast(Callable[[], ContextManager], self.state.model.no_sync)

- with context(): # type: ignore - Pyright apparently doesn't recognize no_sync
- return self._train_batch_inner()
+ with context():
+ return self._train_batch_inner(microbatches)

- def _train_batch_inner(self):
+ def _train_batch_inner(self, microbatches: Sequence[Batch]):
"""Iterate over microbatches and compute the loss that will be used to step
the optimizer.
"""
@@ -725,13 +726,12 @@ def _train_batch_inner(self):

# tracker for gradient accumulation
total_loss = self.device.tensor_to_device(torch.zeros(size=(1,)))
- current_batch_size = sum(
- [self._train_data_spec.get_num_samples_in_batch(batch) for batch in state.microbatches])
+ current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches])

- for state.microbatch_idx, state.batch in enumerate(state.microbatches):
+ for microbatch_idx, state.batch in enumerate(microbatches):
state.batch_num_tokens = self._train_data_spec.get_num_tokens_in_batch(state.batch)
state.batch_num_samples = self._train_data_spec.get_num_samples_in_batch(state.batch)
- is_final_microbatch = state.microbatch_idx + 1 == len(state.microbatches)
+ is_final_microbatch = microbatch_idx + 1 == len(microbatches)
sync_context = contextlib.nullcontext() if self.deepspeed_enabled else ddp_sync_context(
state, is_final_microbatch, self.ddp_sync_strategy)
with sync_context:
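Putting the pieces of `_train_batch_inner` together: the loop walks the locally passed `microbatches`, accumulates gradients, and only needs DDP synchronization on the final microbatch. The self-contained sketch below shows the accumulation arithmetic on a single process; the per-microbatch loss weighting by `current_batch_size` is an illustrative choice, and Composer's real loop additionally handles precision, events, and the configurable `ddp_sync_strategy`.

```python
import contextlib

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batch = (torch.randn(8, 3), torch.randn(8, 1))
grad_accum = 4
microbatches = list(zip(batch[0].chunk(grad_accum), batch[1].chunk(grad_accum)))

# Tracker for gradient accumulation, mirroring `total_loss` above.
total_loss = torch.zeros(1)
current_batch_size = sum(inputs.shape[0] for inputs, _ in microbatches)

optimizer.zero_grad()
for microbatch_idx, (inputs, targets) in enumerate(microbatches):
    is_final_microbatch = microbatch_idx + 1 == len(microbatches)
    # Single-process stand-in for ddp_sync_context: with real DDP, gradients would only
    # be all-reduced when is_final_microbatch is True; here the context is a no-op.
    sync_context = contextlib.nullcontext()
    with sync_context:
        # Sum-reduce and divide by the full batch size so that gradients accumulated
        # across microbatches match a single pass over the whole batch.
        loss = nn.functional.mse_loss(model(inputs), targets, reduction="sum") / current_batch_size
        loss.backward()
        total_loss += loss.detach()

optimizer.step()
print(float(total_loss))
```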