 from composer.core import Callback, DataSpec, Engine, Event, Logger, State, Time
 from composer.core.algorithm import Algorithm
 from composer.core.logging import BaseLoggerBackend, LogLevel
-from composer.core.types import BreakEpochException, DataLoader, Metrics, Precision
+from composer.core.types import Batch, BreakEpochException, DataLoader, Metrics, Precision
 from composer.datasets.dataloader import DDPDataLoader
 from composer.loggers.tqdm_logger import TQDMLoggerBackend
 from composer.models.base import BaseMosaicModel
@@ -583,7 +583,6 @@ def _train_loop(self) -> None:
                 state.batch = self._train_data_spec.device_transforms(state.batch)
                 state.batch_num_samples = self._train_data_spec.get_num_samples_in_batch(state.batch)
                 state.batch_num_tokens = self._train_data_spec.get_num_tokens_in_batch(state.batch)
-                state.microbatches = self._train_data_spec.split_batch(state.batch, state.grad_accum)
 
                 if self.deepspeed_enabled:
                     state.batch = fix_batch_precision_for_deepspeed(state.batch, state.precision)
@@ -593,7 +592,7 @@ def _train_loop(self) -> None:
                     assert train_metrics is not None
                     state.model.eval()
                     with torch.no_grad():
-                        for eval_microbatch in state.microbatches:
+                        for eval_microbatch in self._train_data_spec.split_batch(state.batch, state.grad_accum):
                             # TODO: Detect if self.run_event(Event.AFTER_DATALOADER) changes the training
                             # data and if so print a warning that metrics may return unexpected results
                             outputs, targets = self.original_model.validate(eval_microbatch)
@@ -616,16 +615,18 @@ def _train_loop(self) -> None:
                     "trainer/batch_idx": self.state.timer.batch_in_epoch.value,
                 })
                 total_loss = None
+                microbatches = self._train_data_spec.split_batch(state.batch, state.grad_accum)
                 if self.deepspeed_enabled:
-                    total_loss = self._train_batch()
+                    total_loss = self._train_batch(microbatches)
                 elif self._use_closures():
                     for optimizer in state.optimizers:
                         if use_grad_scaling:
-                            total_loss = state.scaler.step(optimizer, closure=self._train_batch)
+                            total_loss = state.scaler.step(optimizer,
+                                                           closure=lambda: self._train_batch(microbatches))
                         else:
-                            total_loss = optimizer.step(closure=lambda: self._train_batch().item())
+                            total_loss = optimizer.step(closure=lambda: self._train_batch(microbatches).item())
                 else:
-                    total_loss = self._train_batch()
+                    total_loss = self._train_batch(microbatches)
                     for optimizer in state.optimizers:
                         if use_grad_scaling:
                             state.scaler.step(optimizer)
@@ -690,7 +691,7 @@ def _train_loop(self) -> None:
 
         self.engine.run_event(Event.TRAINING_END)
 
-    def _train_batch(self, ddp_sync: bool = True):
+    def _train_batch(self, microbatches: Sequence[Batch], ddp_sync: bool = True):
         """Run training on a full batch of data.
 
         Args:
@@ -702,12 +703,12 @@ def _train_batch(self, ddp_sync: bool = True):
         if ddp_sync or not isinstance(self.state.model, DistributedDataParallel):
             context = contextlib.nullcontext
         else:
-            context = self.state.model.no_sync
+            context = cast(Callable[[], ContextManager], self.state.model.no_sync)
 
-        with context():  # type: ignore - Pyright apparently doesn't recognize no_sync
-            return self._train_batch_inner()
+        with context():
+            return self._train_batch_inner(microbatches)
 
-    def _train_batch_inner(self):
+    def _train_batch_inner(self, microbatches: Sequence[Batch]):
         """Iterate over microbatches and compute the loss that will be used to step
         the optimizer.
         """
@@ -725,13 +726,12 @@ def _train_batch_inner(self):
 
         # tracker for gradient accumulation
        total_loss = self.device.tensor_to_device(torch.zeros(size=(1,)))
-        current_batch_size = sum(
-            [self._train_data_spec.get_num_samples_in_batch(batch) for batch in state.microbatches])
+        current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches])
 
-        for state.microbatch_idx, state.batch in enumerate(state.microbatches):
+        for microbatch_idx, state.batch in enumerate(microbatches):
             state.batch_num_tokens = self._train_data_spec.get_num_tokens_in_batch(state.batch)
             state.batch_num_samples = self._train_data_spec.get_num_samples_in_batch(state.batch)
-            is_final_microbatch = state.microbatch_idx + 1 == len(state.microbatches)
+            is_final_microbatch = microbatch_idx + 1 == len(microbatches)
             sync_context = contextlib.nullcontext() if self.deepspeed_enabled else ddp_sync_context(
                 state, is_final_microbatch, self.ddp_sync_strategy)
             with sync_context:
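Taken together, the change splits the batch once per training step in _train_loop and threads the resulting microbatches explicitly through _train_batch and _train_batch_inner instead of stashing them on state. A self-contained sketch of the gradient-accumulation pattern these signatures encode; the split_batch and train_batch helpers and the (inputs, targets) tuple used as Batch are simplifications for illustration, not Composer's actual API:

    from typing import List, Sequence, Tuple

    import torch
    import torch.nn.functional as F

    Batch = Tuple[torch.Tensor, torch.Tensor]  # (inputs, targets) -- simplified stand-in


    def split_batch(batch: Batch, grad_accum: int) -> List[Batch]:
        # Split one batch into grad_accum microbatches along the sample dimension.
        inputs, targets = batch
        return list(zip(inputs.chunk(grad_accum), targets.chunk(grad_accum)))


    def train_batch(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                    microbatches: Sequence[Batch]) -> torch.Tensor:
        # Accumulate gradients over all microbatches, then take a single optimizer step.
        total_loss = torch.zeros(1)
        batch_size = sum(x.shape[0] for x, _ in microbatches)
        optimizer.zero_grad()
        for x, y in microbatches:
            loss = F.cross_entropy(model(x), y)
            # Weight each microbatch by its share of the full batch so the accumulated
            # gradient matches what a single large-batch step would produce.
            (loss * x.shape[0] / batch_size).backward()
            total_loss += loss.detach().cpu() * x.shape[0] / batch_size
        optimizer.step()
        return total_loss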