30
30
from composer .optim import (ComposedScheduler , CosineAnnealingLRHparams , DecoupledSGDWHparams , OptimizerHparams ,
31
31
SchedulerHparams , WarmUpLRHparams )
32
32
from composer .optim .scheduler import ensure_warmup_last
33
- from composer .trainer .checkpoint import Checkpointer , CheckpointLoader
33
+ from composer .trainer .checkpoint import CheckpointLoader , CheckpointSaver
34
34
from composer .trainer .deepspeed import DeepSpeedHparams
35
35
from composer .trainer .devices .device import Device
36
36
from composer .trainer .devices .device_cpu import DeviceCPU
@@ -93,14 +93,15 @@ class Trainer:
93
93
log_destinations (List[BaseLoggerBackend], optional): The destinations to log training information to.
94
94
(default ``[TQDMLoggerBackend()]``).
95
95
callbacks (Sequence[Callback], optional): The callbacks to run during training. (default: ``[]``)
96
- checkpoint_filepath (str, optional): The path to a trainer checkpoint file. If provided
97
- the trainer will load the state (along with it's associated attributes) during initialization.
98
- (default: ``None``)
99
- checkpoint_interval_unit (int, optional): Unit for the checkpoint save interval -- should be 'ep'
100
- for epochs, 'it' for iterations, or None to disable checkpointing. (default: ``None``).
101
- checkpoint_folder (str, optional): The folder to save checkpoints to. Relative to `os.environ.get('RUN_DIRECTORY', '.')`,
102
- (default: ``checkpoints``)
103
- checkpoint_interval (int, optional): The frequency with which to checkpoint. (default: ``1``)
96
+ checkpoint_filepath (str, optional): The path to an existing checkpoint file to load. (default: ``None``)
97
+ checkpoint_load_weights_only (bool, optional): Whether to only restore the weights from the checkpoint without
98
+ restoring the associated state.
99
+ checkpoint_strict_model_weights (bool, optional): Whether to force that the checkpointed weights must exactly
100
+ match the model weights.
101
+ checkpoint_folder (str): The path to store checkpoints in.
102
+ checkpoint_interval (int, optional): The number of time units to wait between creating checkpoints.
103
+ checkpoint_interval_unit (str, optional): The unit (`"ep"` or `"it"`) that
104
+ `checkpoint_interval` should be measured in. Setting this to ``None`` disables checkpointing. (default: ``None``)
104
105
train_subset_num_batches (int, optional): If specified, finish every epoch early after training
105
106
on this many batches. This parameter has no effect if it is greater than ``len(train_dataloader)``.
106
107
If None (the default), then the entire dataloader will be iterated over.
@@ -150,11 +151,15 @@ def __init__(
150
151
log_destinations : Optional [List [BaseLoggerBackend ]] = None ,
151
152
callbacks : Sequence [Callback ] = tuple (),
152
153
153
- # Checkpoint hparams
154
+ # Checkpoint loading hparams
154
155
checkpoint_filepath : Optional [str ] = None ,
156
+ checkpoint_load_weights_only : bool = False ,
157
+ checkpoint_strict_model_weights : bool = False ,
158
+
159
+ # Checkpoint saving hparams
155
160
checkpoint_interval_unit : Optional [str ] = None ,
156
- checkpoint_folder : Optional [str ] = "checkpoints" ,
157
- checkpoint_interval : Optional [ int ] = 1 ,
161
+ checkpoint_interval : Optional [int ] = None ,
162
+ checkpoint_folder : str = "checkpoints" ,
158
163
159
164
# Subset parameters
160
165
train_subset_num_batches : Optional [int ] = None ,
@@ -295,21 +300,26 @@ def __init__(
295
300
self .state .optimizers = optimizer
296
301
self .state .schedulers = ComposedScheduler (schedulers = schedulers )
297
302
298
- self .checkpointer = None
299
303
# TODO(#121): get checkpointing working with DeepSpeed.
300
- if checkpoint_folder and checkpoint_interval and checkpoint_interval_unit :
304
+ self .checkpoint_saver = None
305
+ if checkpoint_interval is not None and checkpoint_interval_unit is not None :
306
+ self .checkpoint_saver = CheckpointSaver (checkpoint_interval_unit = checkpoint_interval_unit ,
307
+ checkpoint_interval = checkpoint_interval ,
308
+ checkpoint_folder = get_relative_to_run_directory (checkpoint_folder ))
309
+
301
310
if self .deepspeed_enabled :
302
311
raise NotImplementedError ("Checkpointing is not yet supported with DeepSpeed." )
303
- self .checkpointer = Checkpointer (checkpoint_folder = get_relative_to_run_directory (checkpoint_folder ),
304
- checkpoint_interval = checkpoint_interval ,
305
- checkpoint_interval_unit = checkpoint_interval_unit )
306
312
307
- self .checkpoint_loader = None
308
313
# TODO(#121): get checkpointing working with DeepSpeed.
309
- if checkpoint_filepath :
314
+ self .checkpoint_loader = None
315
+ if checkpoint_filepath is not None :
310
316
if self .deepspeed_enabled :
311
317
raise NotImplementedError ("Checkpointing is not yet supported with DeepSpeed." )
312
- self .checkpoint_loader = CheckpointLoader (checkpoint_filepath = checkpoint_filepath )
318
+
319
+ self .checkpoint_loader = CheckpointLoader (checkpoint_filepath = checkpoint_filepath ,
320
+ load_weights_only = checkpoint_load_weights_only ,
321
+ strict_model_weights = checkpoint_strict_model_weights )
322
+
313
323
restored_seed = self .checkpoint_loader .load_checkpoint (state = self .state )
314
324
# Set the restored seed so that the correct seed will be saved in future checkpoints
315
325
# Used to handle the case where another checkpoint is saved after resuming from checkpoint.
@@ -368,6 +378,19 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer:
368
378
each evaluation epoch may load a different subset of samples.""" ))
369
379
eval_dataloader = hparams .val_dataset .initialize_object (eval_device_batch_size , hparams .dataloader )
370
380
381
+ # Checkpoint loading hparams
382
+ checkpoint_filepath = hparams .load_checkpoint .filepath if hparams .load_checkpoint is not None else None
383
+ checkpoint_load_weights_only = hparams .load_checkpoint .load_weights_only \
384
+ if hparams .load_checkpoint is not None else False
385
+ checkpoint_strict_model_weights = hparams .load_checkpoint .strict_model_weights \
386
+ if hparams .load_checkpoint is not None else False
387
+
388
+ # Checkpoint saving hparams
389
+ checkpoint_interval_unit = hparams .save_checkpoint .interval_unit \
390
+ if hparams .save_checkpoint is not None else None
391
+ checkpoint_interval = hparams .save_checkpoint .interval if hparams .save_checkpoint is not None else None
392
+ checkpoint_folder = hparams .save_checkpoint .folder if hparams .save_checkpoint is not None else "checkpoints"
393
+
371
394
trainer = cls (
372
395
model = model ,
373
396
train_dataloader = train_dataloader ,
@@ -400,11 +423,15 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer:
400
423
log_destinations = log_destinations ,
401
424
callbacks = tuple (callbacks ),
402
425
403
- # Checkpointing hparams
404
- checkpoint_filepath = hparams .checkpoint_filepath ,
405
- checkpoint_interval_unit = hparams .checkpoint_interval_unit ,
406
- checkpoint_folder = hparams .checkpoint_folder ,
407
- checkpoint_interval = hparams .checkpoint_interval ,
426
+ # Checkpoint loading hparams
427
+ checkpoint_filepath = checkpoint_filepath ,
428
+ checkpoint_load_weights_only = checkpoint_load_weights_only ,
429
+ checkpoint_strict_model_weights = checkpoint_strict_model_weights ,
430
+
431
+ # Checkpoint saving hparams
432
+ checkpoint_interval_unit = checkpoint_interval_unit ,
433
+ checkpoint_interval = checkpoint_interval ,
434
+ checkpoint_folder = checkpoint_folder ,
408
435
409
436
# Subset parameters
410
437
train_subset_num_batches = hparams .train_subset_num_batches ,
@@ -674,11 +701,12 @@ def _ddp_reduce_tensor_sum(tensor: Tensor) -> Tensor:
674
701
self .eval (is_batch = True )
675
702
676
703
state .step += 1
677
- if self .checkpointer and self .checkpointer .should_checkpoint (state = state , event = Event .BATCH_END ):
678
- self .checkpointer .save_checkpoint (state = state ,
679
- seed = self .seed ,
680
- device = self .device ,
681
- config = self .config )
704
+ if self .checkpoint_saver and self .checkpoint_saver .should_checkpoint (state = state ,
705
+ event = Event .BATCH_END ):
706
+ self .checkpoint_saver .save_checkpoint (state = state ,
707
+ seed = self .seed ,
708
+ device = self .device ,
709
+ config = self .config )
682
710
except BreakEpochException :
683
711
log .info (f'Skipping the rest of Epoch { state .epoch } ' )
684
712
@@ -692,8 +720,11 @@ def _ddp_reduce_tensor_sum(tensor: Tensor) -> Tensor:
692
720
693
721
state .epoch += 1
694
722
695
- if self .checkpointer and self .checkpointer .should_checkpoint (state = state , event = Event .EPOCH_END ):
696
- self .checkpointer .save_checkpoint (state = state , seed = self .seed , device = self .device , config = self .config )
723
+ if self .checkpoint_saver and self .checkpoint_saver .should_checkpoint (state = state , event = Event .EPOCH_END ):
724
+ self .checkpoint_saver .save_checkpoint (state = state ,
725
+ seed = self .seed ,
726
+ device = self .device ,
727
+ config = self .config )
697
728
698
729
self .engine .run_event (Event .TRAINING_END )
699
730
0 commit comments