diff --git a/composer/core/state.py b/composer/core/state.py
index b778d88d65..4b954e3c8f 100755
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -57,8 +57,6 @@
     "batch",
     "batch_num_samples",
     "batch_num_tokens",
-    "microbatches",
-    "microbatch_idx",
     "outputs",
     "train_dataloader",
     "eval_dataloader",
@@ -102,8 +100,6 @@ class State(Serializable):
             microbatch between :attr:`Event.BATCH_START` and :attr:`Event.BATCH_END`.
         batch_num_samples (int): The number of samples in the :attr:`batch`.
         batch_num_tokens (int): The number of tokens in the :attr:`batch`.
-        microbatches (Sequence[type.Batch]): The batch, split into ``grad_accum`` microbatches.
-        microbatch_idx (int): The current microbatch index, which will be on [0, ``grad_accum``).
         loss (types.Tensors): The most recently computed loss.
         outputs (types.Tensors): The most recently computed output from the model's forward pass.
 
@@ -114,8 +110,6 @@ class State(Serializable):
     batch: types.Batch
     batch_num_samples: int
     batch_num_tokens: int
-    microbatches: Sequence[types.Batch]
-    microbatch_idx: int
     loss: types.Tensors
     outputs: types.Tensors
 
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 91a158f81f..4f34f25ffb 100755
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -21,7 +21,7 @@
 from composer.core import Callback, DataSpec, Engine, Event, Logger, State, Time
 from composer.core.algorithm import Algorithm
 from composer.core.logging import BaseLoggerBackend, LogLevel
-from composer.core.types import BreakEpochException, DataLoader, Metrics, Precision
+from composer.core.types import Batch, BreakEpochException, DataLoader, Metrics, Precision
 from composer.datasets.dataloader import DDPDataLoader
 from composer.loggers.tqdm_logger import TQDMLoggerBackend
 from composer.models.base import BaseMosaicModel
@@ -583,7 +583,6 @@ def _train_loop(self) -> None:
                     state.batch = self._train_data_spec.device_transforms(state.batch)
                     state.batch_num_samples = self._train_data_spec.get_num_samples_in_batch(state.batch)
                     state.batch_num_tokens = self._train_data_spec.get_num_tokens_in_batch(state.batch)
-                    state.microbatches = self._train_data_spec.split_batch(state.batch, state.grad_accum)
 
                     if self.deepspeed_enabled:
                         state.batch = fix_batch_precision_for_deepspeed(state.batch, state.precision)
@@ -593,7 +592,7 @@ def _train_loop(self) -> None:
                         assert train_metrics is not None
                         state.model.eval()
                         with torch.no_grad():
-                            for eval_microbatch in state.microbatches:
+                            for eval_microbatch in self._train_data_spec.split_batch(state.batch, state.grad_accum):
                                 # TODO: Detect if self.run_event(Event.AFTER_DATALOADER) changes the training
                                 # data and if so print a warning that metrics may return unexpected results
                                 outputs, targets = self.original_model.validate(eval_microbatch)
@@ -616,16 +615,18 @@ def _train_loop(self) -> None:
                         "trainer/batch_idx": self.state.timer.batch_in_epoch.value,
                     })
                     total_loss = None
+                    microbatches = self._train_data_spec.split_batch(state.batch, state.grad_accum)
                     if self.deepspeed_enabled:
-                        total_loss = self._train_batch()
+                        total_loss = self._train_batch(microbatches)
                     elif self._use_closures():
                         for optimizer in state.optimizers:
                             if use_grad_scaling:
-                                total_loss = state.scaler.step(optimizer, closure=self._train_batch)
+                                total_loss = state.scaler.step(optimizer,
+                                                               closure=lambda: self._train_batch(microbatches))
                             else:
-                                total_loss = optimizer.step(closure=lambda: self._train_batch().item())
+                                total_loss = optimizer.step(closure=lambda: self._train_batch(microbatches).item())
                     else:
-                        total_loss = self._train_batch()
+                        total_loss = self._train_batch(microbatches)
                         for optimizer in state.optimizers:
                             if use_grad_scaling:
                                 state.scaler.step(optimizer)
@@ -690,7 +691,7 @@ def _train_loop(self) -> None:
 
         self.engine.run_event(Event.TRAINING_END)
 
-    def _train_batch(self, ddp_sync: bool = True):
+    def _train_batch(self, microbatches: Sequence[Batch], ddp_sync: bool = True):
         """Run training on a full batch of data.
 
         Args:
@@ -702,12 +703,12 @@ def _train_batch(self, microbatches: Sequence[Batch], ddp_sync: bool = True):
         if ddp_sync or not isinstance(self.state.model, DistributedDataParallel):
             context = contextlib.nullcontext
         else:
-            context = self.state.model.no_sync
+            context = cast(Callable[[], ContextManager], self.state.model.no_sync)
 
-        with context():  # type: ignore - Pyright apparently doesn't recognize no_sync
-            return self._train_batch_inner()
+        with context():
+            return self._train_batch_inner(microbatches)
 
-    def _train_batch_inner(self):
+    def _train_batch_inner(self, microbatches: Sequence[Batch]):
         """Iterate over microbatches and compute the loss that will be used to step
         the optimizer.
         """
@@ -725,13 +726,12 @@ def _train_batch_inner(self):
 
         # tracker for gradient accumulation
        total_loss = self.device.tensor_to_device(torch.zeros(size=(1,)))
-        current_batch_size = sum(
-            [self._train_data_spec.get_num_samples_in_batch(batch) for batch in state.microbatches])
+        current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches])
 
-        for state.microbatch_idx, state.batch in enumerate(state.microbatches):
+        for microbatch_idx, state.batch in enumerate(microbatches):
             state.batch_num_tokens = self._train_data_spec.get_num_tokens_in_batch(state.batch)
             state.batch_num_samples = self._train_data_spec.get_num_samples_in_batch(state.batch)
-            is_final_microbatch = state.microbatch_idx + 1 == len(state.microbatches)
+            is_final_microbatch = microbatch_idx + 1 == len(microbatches)
             sync_context = contextlib.nullcontext() if self.deepspeed_enabled else ddp_sync_context(
                 state, is_final_microbatch, self.ddp_sync_strategy)
             with sync_context:
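
For context on the change above: after this diff, gradient accumulation operates on microbatches that are split once per batch and passed explicitly into `_train_batch`, rather than being read off `State`. The snippet below is a minimal, standalone sketch of that pattern in plain PyTorch, not Composer's API: `split_batch`, `train_batch`, and the per-microbatch loss weighting are illustrative assumptions, though summing samples across all microbatches mirrors how `_train_batch_inner` computes `current_batch_size`.

```python
# Standalone sketch (assumed names, not Composer code): one optimizer step with
# gradients accumulated over explicit microbatches.
from typing import List, Tuple

import torch

Batch = Tuple[torch.Tensor, torch.Tensor]  # illustrative stand-in for types.Batch


def split_batch(batch: Batch, grad_accum: int) -> List[Batch]:
    """Split an (inputs, targets) batch into grad_accum microbatches (hypothetical helper)."""
    xs, ys = batch
    return list(zip(xs.chunk(grad_accum), ys.chunk(grad_accum)))


def train_batch(model: torch.nn.Module, optimizer: torch.optim.Optimizer, batch: Batch,
                grad_accum: int) -> torch.Tensor:
    """Accumulate gradients over microbatches, then step the optimizer once."""
    microbatches = split_batch(batch, grad_accum)
    # Total sample count across microbatches, analogous to current_batch_size above.
    batch_size = sum(x.shape[0] for x, _ in microbatches)
    total_loss = torch.zeros(1)

    optimizer.zero_grad()
    for x, y in microbatches:
        loss = torch.nn.functional.mse_loss(model(x), y)
        # Weight each microbatch by its share of the full batch so the accumulated
        # gradient matches what a single large-batch step would produce.
        loss = loss * (x.shape[0] / batch_size)
        loss.backward()
        total_loss += loss.detach()
    optimizer.step()
    return total_loss


if __name__ == "__main__":
    model = torch.nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    batch = (torch.randn(8, 4), torch.randn(8, 1))
    print(train_batch(model, opt, batch, grad_accum=4))
```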