Commit f135383

Fix bug with AFTER_DATALOADER event; remove microbatches from state (#258)
1. #223 introduced a bug where algorithms that ran on Event.AFTER_DATALOADER and modified state.batch out of place did not also update state.microbatches (which is what training actually used), so those algorithms were effectively ignored. Fixed this bug by computing the microbatches after the Event.AFTER_DATALOADER event has run.
2. Removed `microbatches` and `microbatch_idx` from the state. Algorithms that need to run on the smaller, forward-pass-sized batches should use the Event.BATCH_START event instead of Event.AFTER_DATALOADER, since Event.BATCH_START receives the forward-pass-sized batch.
1 parent cc5e8d9 commit f135383
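
For illustration, a minimal sketch of the pattern the new guidance implies: an algorithm that rebuilds state.batch out of place and hooks Event.BATCH_START so it sees the forward-pass-sized batch. The HalveInputs class and its scaling are hypothetical; the imports follow the ones visible in the trainer.py diff below, and the (inputs, targets) tuple batch is an assumption.

from composer.core import Event, Logger, State
from composer.core.algorithm import Algorithm


class HalveInputs(Algorithm):
    """Hypothetical algorithm that replaces state.batch out of place."""

    def match(self, event: Event, state: State) -> bool:
        # Per this commit, BATCH_START (not AFTER_DATALOADER) receives the
        # smaller, forward-pass-sized batch when grad_accum > 1.
        return event == Event.BATCH_START

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        inputs, targets = state.batch  # assumes an (inputs, targets) tuple batch
        # Out-of-place replacement: before this fix, doing this at AFTER_DATALOADER
        # was silently ignored because state.microbatches had already been split.
        state.batch = (inputs * 0.5, targets)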

2 files changed: 16 additions, 22 deletions

composer/core/state.py

Lines changed: 0 additions & 6 deletions
@@ -57,8 +57,6 @@
     "batch",
     "batch_num_samples",
     "batch_num_tokens",
-    "microbatches",
-    "microbatch_idx",
     "outputs",
     "train_dataloader",
     "eval_dataloader",
@@ -102,8 +100,6 @@ class State(Serializable):
             microbatch between :attr:`Event.BATCH_START` and :attr:`Event.BATCH_END`.
         batch_num_samples (int): The number of samples in the :attr:`batch`.
         batch_num_tokens (int): The number of tokens in the :attr:`batch`.
-        microbatches (Sequence[type.Batch]): The batch, split into ``grad_accum`` microbatches.
-        microbatch_idx (int): The current microbatch index, which will be on [0, ``grad_accum``).

         loss (types.Tensors): The most recently computed loss.
         outputs (types.Tensors): The most recently computed output from the model's forward pass.
@@ -114,8 +110,6 @@ class State(Serializable):
     batch: types.Batch
     batch_num_samples: int
     batch_num_tokens: int
-    microbatches: Sequence[types.Batch]
-    microbatch_idx: int
     loss: types.Tensors
     outputs: types.Tensors

composer/trainer/trainer.py

Lines changed: 16 additions & 16 deletions
@@ -21,7 +21,7 @@
 from composer.core import Callback, DataSpec, Engine, Event, Logger, State, Time
 from composer.core.algorithm import Algorithm
 from composer.core.logging import BaseLoggerBackend, LogLevel
-from composer.core.types import BreakEpochException, DataLoader, Metrics, Precision
+from composer.core.types import Batch, BreakEpochException, DataLoader, Metrics, Precision
 from composer.datasets.dataloader import DDPDataLoader
 from composer.loggers.tqdm_logger import TQDMLoggerBackend
 from composer.models.base import BaseMosaicModel
@@ -583,7 +583,6 @@ def _train_loop(self) -> None:
                 state.batch = self._train_data_spec.device_transforms(state.batch)
                 state.batch_num_samples = self._train_data_spec.get_num_samples_in_batch(state.batch)
                 state.batch_num_tokens = self._train_data_spec.get_num_tokens_in_batch(state.batch)
-                state.microbatches = self._train_data_spec.split_batch(state.batch, state.grad_accum)

                 if self.deepspeed_enabled:
                     state.batch = fix_batch_precision_for_deepspeed(state.batch, state.precision)
@@ -593,7 +592,7 @@ def _train_loop(self) -> None:
                     assert train_metrics is not None
                     state.model.eval()
                     with torch.no_grad():
-                        for eval_microbatch in state.microbatches:
+                        for eval_microbatch in self._train_data_spec.split_batch(state.batch, state.grad_accum):
                             # TODO: Detect if self.run_event(Event.AFTER_DATALOADER) changes the training
                             # data and if so print a warning that metrics may return unexpected results
                             outputs, targets = self.original_model.validate(eval_microbatch)
@@ -616,16 +615,18 @@ def _train_loop(self) -> None:
                     "trainer/batch_idx": self.state.timer.batch_in_epoch.value,
                 })
                 total_loss = None
+                microbatches = self._train_data_spec.split_batch(state.batch, state.grad_accum)
                 if self.deepspeed_enabled:
-                    total_loss = self._train_batch()
+                    total_loss = self._train_batch(microbatches)
                 elif self._use_closures():
                     for optimizer in state.optimizers:
                         if use_grad_scaling:
-                            total_loss = state.scaler.step(optimizer, closure=self._train_batch)
+                            total_loss = state.scaler.step(optimizer,
+                                                           closure=lambda: self._train_batch(microbatches))
                         else:
-                            total_loss = optimizer.step(closure=lambda: self._train_batch().item())
+                            total_loss = optimizer.step(closure=lambda: self._train_batch(microbatches).item())
                 else:
-                    total_loss = self._train_batch()
+                    total_loss = self._train_batch(microbatches)
                 for optimizer in state.optimizers:
                     if use_grad_scaling:
                         state.scaler.step(optimizer)
@@ -690,7 +691,7 @@ def _train_loop(self) -> None:

         self.engine.run_event(Event.TRAINING_END)

-    def _train_batch(self, ddp_sync: bool = True):
+    def _train_batch(self, microbatches: Sequence[Batch], ddp_sync: bool = True):
         """Run training on a full batch of data.

         Args:
@@ -702,12 +703,12 @@ def _train_batch(self, ddp_sync: bool = True):
         if ddp_sync or not isinstance(self.state.model, DistributedDataParallel):
             context = contextlib.nullcontext
         else:
-            context = self.state.model.no_sync
+            context = cast(Callable[[], ContextManager], self.state.model.no_sync)

-        with context():  # type: ignore - Pyright apparently doesn't recognize no_sync
-            return self._train_batch_inner()
+        with context():
+            return self._train_batch_inner(microbatches)

-    def _train_batch_inner(self):
+    def _train_batch_inner(self, microbatches: Sequence[Batch]):
         """Iterate over microbatches and compute the loss that will be used to step
         the optimizer.
         """
@@ -725,13 +726,12 @@ def _train_batch_inner(self):

         # tracker for gradient accumulation
         total_loss = self.device.tensor_to_device(torch.zeros(size=(1,)))
-        current_batch_size = sum(
-            [self._train_data_spec.get_num_samples_in_batch(batch) for batch in state.microbatches])
+        current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches])

-        for state.microbatch_idx, state.batch in enumerate(state.microbatches):
+        for microbatch_idx, state.batch in enumerate(microbatches):
             state.batch_num_tokens = self._train_data_spec.get_num_tokens_in_batch(state.batch)
             state.batch_num_samples = self._train_data_spec.get_num_samples_in_batch(state.batch)
-            is_final_microbatch = state.microbatch_idx + 1 == len(state.microbatches)
+            is_final_microbatch = microbatch_idx + 1 == len(microbatches)
             sync_context = contextlib.nullcontext() if self.deepspeed_enabled else ddp_sync_context(
                 state, is_final_microbatch, self.ddp_sync_strategy)
             with sync_context:
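
As a stand-alone illustration of what the trainer now computes up front, here is a rough stand-in for DataSpec.split_batch applied to a simple (inputs, targets) tuple batch; the helper below is hypothetical and only sketches the semantics, not the DataSpec implementation.

import torch


def split_batch(batch, grad_accum):
    # Chunk a tuple batch into grad_accum microbatches (illustrative only).
    inputs, targets = batch
    return list(zip(inputs.chunk(grad_accum), targets.chunk(grad_accum)))


batch = (torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
microbatches = split_batch(batch, grad_accum=4)
assert len(microbatches) == 4
assert microbatches[0][0].shape[0] == 2  # 8 samples split across 4 microbatches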
