Commit 2b25192

Authored by Moin Nadeem (moinnadeem) and ravi-mosaicml
Add the ability to load a checkpoint without restoring state (#169)
* fixing checkpoint bug
* finalizing fine-tuning a checkpointed model
* address PR feedback
* adding save_checkpoint and load_checkpoint hparams interface
* yapf & pyright
* changing interface
* everyone always asks 'what is yapf', but never 'how is yapf'?
* renaming Checkpointer -> CheckpointSaver
* renaming Checkpointer -> CheckpointSaver
* addressing feedback & friendly renaming
* addressing pyright
* yapf
* adding tests
* moving commits to BERT branch
* changing folder to be relative to run dir
* adding tests
* pyright part 1
* pyright on trainer file
* moving restoring RNG & random seed to else clause
* Fix tests
* Addressed comments

Co-authored-by: Moin Nadeem <[email protected]>
Co-authored-by: Ravi Rahman <[email protected]>
1 parent ba56b2c commit 2b25192

File tree

8 files changed: +309 / -76 lines

composer/core/state.py

Lines changed: 19 additions & 4 deletions
@@ -213,7 +213,24 @@ def state_dict(self) -> types.StateDict:
         state_dict["_is_model_ddp_wrapped"] = isinstance(self.model, DistributedDataParallel)
         return state_dict

-    def load_state_dict(self, state: types.StateDict):
+    def load_model_state(self, state_dict: types.StateDict, strict: bool):
+        """
+        Loads the model's state from a state_dict.
+
+        Args:
+            state_dict (types.StateDict): object returned from call to :meth:`state_dict`.
+            strict (bool): whether the keys in the state_dict should perfectly match the keys in the model.
+        """
+        if state_dict["_is_model_ddp_wrapped"] and not isinstance(self.model, DistributedDataParallel):
+            torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], "module.")
+
+        missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
+        if len(missing_keys) > 0:
+            logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
+        if len(unexpected_keys) > 0:
+            logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
+
+    def load_state_dict(self, state: types.StateDict, strict: bool = False):
         """Loads the state.

         Args:
@@ -229,9 +246,7 @@ def load_state_dict(self, state: types.StateDict):
             serialized_value = state[state_field_name]

             if state_field_name == "model":
-                if state["_is_model_ddp_wrapped"] and not isinstance(self.model, DistributedDataParallel):
-                    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(serialized_value, "module.")
-                state_field_value.load_state_dict(serialized_value)
+                self.load_model_state(state, strict=strict)
             else:
                 for target in ensure_tuple(state_field_value):
                     if target is None:
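
The non-strict path of load_model_state leans on two stock PyTorch behaviors: consume_prefix_in_state_dict_if_present strips the "module." prefix that DistributedDataParallel adds to every key, and load_state_dict(strict=False) reports mismatched keys instead of raising. A minimal standalone sketch of both, using a toy model that is purely illustrative and not part of this commit:

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present  # available in PyTorch >= 1.9

# Toy two-layer model standing in for state.model; purely illustrative.
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

# Pretend this checkpoint came from a DDP-wrapped copy of the model:
# every key carries a "module." prefix.
ddp_state = {f"module.{k}": v for k, v in model.state_dict().items()}

# Same call load_model_state makes: strips the prefix in place when present.
consume_prefix_in_state_dict_if_present(ddp_state, "module.")

# strict=False tolerates mismatched keys and reports them instead of raising,
# which load_model_state then surfaces as warnings.
missing, unexpected = model.load_state_dict(ddp_state, strict=False)
print("missing:", missing, "unexpected:", unexpected)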

composer/datasets/hparams.py

Lines changed: 3 additions & 3 deletions (whitespace-only: trailing whitespace stripped from docstring blank lines)

@@ -40,7 +40,7 @@ def _split_fn(batch: Batch, n_microbatches: int) -> List[Batch]:
 class DataloaderSpec(NamedTuple):
     """Specification for initializing a dataloader when a device transformation function or split function
     is required
-
+
     Parameters:
         dataloader (DataLoader): The initialized dataloader.
         device_transform_fn (TDeviceTransformFn, optional):
@@ -104,12 +104,12 @@ class DatasetHparams(hp.Hparams, abc.ABC, metaclass=metaclass):
     def initialize_object(self, batch_size: int,
                           dataloader_hparams: DataloaderHparams) -> Union[DataLoader, DataloaderSpec]:
         """Creates a :class:`DataLoader` or :class:`DataloaderSpec` for this dataset.
-
+
         Parameters:
             batch_size (int): The size of the batch the dataloader should yield. This batch size is
                 device-specific and already incorporates the world size.
            dataloader_hparams (DataloaderHparams): The dataset-independent hparams for the dataloader
-
+
         Returns:
             Dataloader or DataloaderSpec: The dataloader, or if a custom device transformation
                 or split function is required, a :class:`DataloaderSpec` tuple

composer/optim/pytorch_future.py

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ def __init__(self,
                  interval='step'):
         if warmup_method not in ("constant", "linear"):
             raise ValueError("Only 'constant' or 'linear' warmup_method accepted, but got {}".format(warmup_method))
+
         self.warmup_factor = warmup_factor
         self.warmup_iters = warmup_iters
         self.warmup_method = warmup_method

composer/trainer/checkpoint.py

Lines changed: 23 additions & 16 deletions
@@ -23,10 +23,14 @@ class CheckpointLoader:

     Args:
         checkpoint_filepath (str): The path to an existing checkpoint file.
+        load_weights_only (bool): Whether to only restore the weights from the checkpoint without restoring the associated state.
+        strict_model_weights (bool): Whether to force that the checkpointed weights must exactly match the model weights.
     """

-    def __init__(self, checkpoint_filepath: str):
+    def __init__(self, checkpoint_filepath: str, load_weights_only: bool = False, strict_model_weights: bool = False):
         self.state_dict = torch.load(checkpoint_filepath, map_location='cpu')
+        self.load_weights_only = load_weights_only
+        self.strict_model_weights = strict_model_weights
         self.checkpoint_rng_state = None

     def load_checkpoint(self, state: State):
@@ -39,19 +43,22 @@ def load_checkpoint(self, state: State):
             The seed that was loaded from the checkpoint if it exists otherwise `None`.
         """

-        state.load_state_dict(self.state_dict["state"])
-        self.checkpoint_rng_state = self._get_checkpoint_rng_state(state, self.state_dict["rng"])
-
-        if "seed" in self.state_dict:
-            world_size = ddp.get_world_size()
-            checkpointed_world_size = len(self.state_dict["seed"])
-            if world_size != checkpointed_world_size:
-                warnings.warn(f"Current world size {world_size} does not match the checkpointed world size "
-                              f"{checkpointed_world_size}. The seed will not be restored.")
-                return
-            seed_to_restore = self.state_dict["seed"][ddp.get_global_rank()]
-            seed_all(seed_to_restore)
-            return seed_to_restore
+        if self.load_weights_only:
+            state.load_model_state(self.state_dict['state'], strict=self.strict_model_weights)
+        else:
+            state.load_state_dict(self.state_dict["state"])
+            self.checkpoint_rng_state = self._get_checkpoint_rng_state(state, self.state_dict["rng"])
+
+            if "seed" in self.state_dict:
+                world_size = ddp.get_world_size()
+                checkpointed_world_size = len(self.state_dict["seed"])
+                if world_size != checkpointed_world_size:
+                    warnings.warn(f"Current world size {world_size} does not match the checkpointed world size "
+                                  f"{checkpointed_world_size}. The seed will not be restored.")
+                    return
+                seed_to_restore = self.state_dict["seed"][ddp.get_global_rank()]
+                seed_all(seed_to_restore)
+                return seed_to_restore

     def restore_checkpoint_rng_state(self, state: State, device: Device):
         """Restore the state of all RNG objects in this context from the loaded checkpoint's data.
@@ -82,11 +89,11 @@ def _get_checkpoint_rng_state(self, state: State, checkpoint_rng_state: StateDict
                           f"RNG state will not be restored.")


-class Checkpointer:
+class CheckpointSaver:
     """Manager for saving state to checkpoint files.

     Args:
-        checkpoint_folder (str): The path to the folder to store checkpoints in.
+        checkpoint_folder (str): The path to store checkpoints in.
         checkpoint_interval (int): The amount of time units to wait between checkpoints.
         checkpoint_interval_unit (str): The unit (`"ep"` or `"it"`) that
             `checkpoint_interval` should be measured in.
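
Taken together, the new flags give CheckpointLoader two modes: weights-only loading for fine-tuning, and full-state loading for resuming. A hedged usage sketch; the checkpoint path is a placeholder and `state` is assumed to be an already-constructed composer.core.State, neither of which appears in this diff:

from composer.trainer.checkpoint import CheckpointLoader

# Fine-tuning: restore only the model weights; optimizer, scheduler, RNG, and seed are untouched.
finetune_loader = CheckpointLoader("pretrain/checkpoints/ep10.pt",  # placeholder path
                                   load_weights_only=True,
                                   strict_model_weights=False)
finetune_loader.load_checkpoint(state=state)  # restores weights only; returns no seed

# Resuming: restore the full training state; the seed comes back if the world size matches the checkpoint's.
resume_loader = CheckpointLoader("pretrain/checkpoints/ep10.pt")
restored_seed = resume_loader.load_checkpoint(state=state)
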
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+# Copyright 2021 MosaicML. All Rights Reserved.
+
+import logging
+from dataclasses import dataclass
+
+import yahp as hp
+
+from composer.trainer.checkpoint import CheckpointLoader, CheckpointSaver
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class CheckpointLoaderHparams(hp.Hparams):
+    """Hparams for the :class:`CheckpointLoader`.
+
+    See the documentation for the :class:`CheckpointLoader`.
+    """
+    filepath: str = hp.required(doc="Path to the serialized state_dict to recover state from.")
+    load_weights_only: bool = hp.optional(doc="Whether to only load the weights from the model.", default=False)
+    strict_model_weights: bool = hp.optional(
+        doc="Ensure that the set of weights in the checkpoint and model must exactly match.", default=False)
+
+    def validate(self):
+        if not self.load_weights_only and self.strict_model_weights:
+            raise ValueError(
+                "Strict cannot be used when load_weights_only is true. Restoring a checkpoint from previous state assumes that the checkpoint should perfectly match the model."
+            )
+
+    def initialize_object(self) -> CheckpointLoader:
+        return CheckpointLoader(checkpoint_filepath=self.filepath,
+                                load_weights_only=self.load_weights_only,
+                                strict_model_weights=self.strict_model_weights)
+
+
+@dataclass
+class CheckpointSaverHparams(hp.Hparams):
+    """Hparams for the :class:`CheckpointSaver`.
+
+    See the documentation for the :class:`CheckpointSaver`.
+    """
+    interval_unit: str = hp.required(
+        doc="Unit for the checkpoint save interval -- should be 'ep' for epochs; 'it' for iterations")
+    interval: int = hp.required(doc="Interval for checkpointing.")
+    folder: str = hp.optional(doc="Folder in which to save checkpoint files. Relative to the run directory, if set."
+                              "Defaults to `checkpoints`.",
+                              default="checkpoints")
+
+    def validate(self):
+        if self.interval < 0:
+            raise ValueError("Checkpointing interval must be greater than zero.")
+        if self.interval_unit not in ['ep', 'it']:
+            raise ValueError("Checkpointing interval unit must be one of 'ep' for epochs, or 'it' for iterations.")
+
+    def initialize_object(self) -> CheckpointSaver:
+        return CheckpointSaver(checkpoint_interval_unit=self.interval_unit,
+                               checkpoint_interval=self.interval,
+                               checkpoint_folder=self.folder)
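
A short sketch of exercising these hparams classes directly in Python rather than from a YAML config. The filepath is a placeholder, and the classes are assumed to be imported from the new hparams module above (its path is not shown on this page):

# Loading side: strict weight matching is only meaningful together with load_weights_only=True.
load_hp = CheckpointLoaderHparams(filepath="pretrain/checkpoints/ep10.pt",  # placeholder path
                                  load_weights_only=True,
                                  strict_model_weights=True)
load_hp.validate()                    # passes; strict without load_weights_only would raise ValueError
loader = load_hp.initialize_object()  # -> CheckpointLoader with the same settings

# Saving side: checkpoint once per epoch into the default 'checkpoints' folder.
save_hp = CheckpointSaverHparams(interval_unit="ep", interval=1)
save_hp.validate()                    # 'ep'/'it' and a non-negative interval are accepted
saver = save_hp.initialize_object()   # -> CheckpointSaver(checkpoint_interval_unit="ep", ...)
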

composer/trainer/trainer.py

Lines changed: 63 additions & 32 deletions
@@ -30,7 +30,7 @@
 from composer.optim import (ComposedScheduler, CosineAnnealingLRHparams, DecoupledSGDWHparams, OptimizerHparams,
                             SchedulerHparams, WarmUpLRHparams)
 from composer.optim.scheduler import ensure_warmup_last
-from composer.trainer.checkpoint import Checkpointer, CheckpointLoader
+from composer.trainer.checkpoint import CheckpointLoader, CheckpointSaver
 from composer.trainer.deepspeed import DeepSpeedHparams
 from composer.trainer.devices.device import Device
 from composer.trainer.devices.device_cpu import DeviceCPU
@@ -93,14 +93,15 @@ class Trainer:
         log_destinations (List[BaseLoggerBackend], optional): The destinations to log training information to.
             (default ``[TQDMLoggerBackend()]``).
         callbacks (Sequence[Callback], optional): The callbacks to run during training. (default: ``[]``)
-        checkpoint_filepath (str, optional): The path to a trainer checkpoint file. If provided
-            the trainer will load the state (along with it's associated attributes) during initialization.
-            (default: ``None``)
-        checkpoint_interval_unit (int, optional): Unit for the checkpoint save interval -- should be 'ep'
-            for epochs, 'it' for iterations, or None to disable checkpointing. (default: ``None``).
-        checkpoint_folder (str, optional): The folder to save checkpoints to. Relative to `os.environ.get('RUN_DIRECTORY', '.')`,
-            (default: ``checkpoints``)
-        checkpoint_interval (int, optional): The frequency with which to checkpoint. (default: ``1``)
+        checkpoint_filepath (str): For loading checkpoints, the path to an existing checkpoint file.
+        load_weights_only (bool): Whether to only restore the weights from the checkpoint without
+            restoring the associated state.
+        strict_model_weights (bool, optional): Whether to force that the checkpointed weights must exactly
+            match the model weights.
+        checkpoint_folder (str): The path to store checkpoints in.
+        checkpoint_interval (int): The amount of time units to wait between creating checkpoints.
+        checkpoint_interval_unit (str, optional): The unit (`"ep"` or `"it"`) that
+            `checkpoint_interval` should be measured in. Set to ``None`` disables checkpointing. (default: ``None``)
         train_subset_num_batches (int, optional): If specified, finish every epoch early after training
             on this many batches. This parameter has no effect if it is greater than ``len(train_dataloader)``.
             If None (the default), then the entire dataloader will be iterated over.
@@ -150,11 +151,15 @@ def __init__(
         log_destinations: Optional[List[BaseLoggerBackend]] = None,
         callbacks: Sequence[Callback] = tuple(),

-        # Checkpoint hparams
+        # Checkpoint loading hparams
         checkpoint_filepath: Optional[str] = None,
+        checkpoint_load_weights_only: bool = False,
+        checkpoint_strict_model_weights: bool = False,
+
+        # Checkpoint saving hparams
         checkpoint_interval_unit: Optional[str] = None,
-        checkpoint_folder: Optional[str] = "checkpoints",
-        checkpoint_interval: Optional[int] = 1,
+        checkpoint_interval: Optional[int] = None,
+        checkpoint_folder: str = "checkpoints",

         # Subset parameters
         train_subset_num_batches: Optional[int] = None,
@@ -295,21 +300,26 @@ def __init__(
         self.state.optimizers = optimizer
         self.state.schedulers = ComposedScheduler(schedulers=schedulers)

-        self.checkpointer = None
         # TODO(#121): get checkpointing working with DeepSpeed.
-        if checkpoint_folder and checkpoint_interval and checkpoint_interval_unit:
+        self.checkpoint_saver = None
+        if checkpoint_interval is not None and checkpoint_interval_unit is not None:
+            self.checkpoint_saver = CheckpointSaver(checkpoint_interval_unit=checkpoint_interval_unit,
+                                                    checkpoint_interval=checkpoint_interval,
+                                                    checkpoint_folder=get_relative_to_run_directory(checkpoint_folder))
+
             if self.deepspeed_enabled:
                 raise NotImplementedError("Checkpointing is not yet supported with DeepSpeed.")
-            self.checkpointer = Checkpointer(checkpoint_folder=get_relative_to_run_directory(checkpoint_folder),
-                                             checkpoint_interval=checkpoint_interval,
-                                             checkpoint_interval_unit=checkpoint_interval_unit)

-        self.checkpoint_loader = None
         # TODO(#121): get checkpointing working with DeepSpeed.
-        if checkpoint_filepath:
+        self.checkpoint_loader = None
+        if checkpoint_filepath is not None:
             if self.deepspeed_enabled:
                 raise NotImplementedError("Checkpointing is not yet supported with DeepSpeed.")
-            self.checkpoint_loader = CheckpointLoader(checkpoint_filepath=checkpoint_filepath)
+
+            self.checkpoint_loader = CheckpointLoader(checkpoint_filepath=checkpoint_filepath,
+                                                      load_weights_only=checkpoint_load_weights_only,
+                                                      strict_model_weights=checkpoint_strict_model_weights)
+
             restored_seed = self.checkpoint_loader.load_checkpoint(state=self.state)
             # Set the restored seed so that the correct seed will be saved in future checkpoints
             # Used to handle the case where another checkpoint is saved after resuming from checkpoint.
@@ -368,6 +378,19 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer:
                     each evaluation epoch may load a different subset of samples."""))
         eval_dataloader = hparams.val_dataset.initialize_object(eval_device_batch_size, hparams.dataloader)

+        # Checkpoint loading hparams
+        checkpoint_filepath = hparams.load_checkpoint.filepath if hparams.load_checkpoint is not None else None
+        checkpoint_load_weights_only = hparams.load_checkpoint.load_weights_only \
+            if hparams.load_checkpoint is not None else False
+        checkpoint_strict_model_weights = hparams.load_checkpoint.strict_model_weights \
+            if hparams.load_checkpoint is not None else False
+
+        # Checkpoint saving hparams
+        checkpoint_interval_unit = hparams.save_checkpoint.interval_unit \
+            if hparams.save_checkpoint is not None else None
+        checkpoint_interval = hparams.save_checkpoint.interval if hparams.save_checkpoint is not None else None
+        checkpoint_folder = hparams.save_checkpoint.folder if hparams.save_checkpoint is not None else "checkpoints"
+
         trainer = cls(
             model=model,
             train_dataloader=train_dataloader,
@@ -400,11 +423,15 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer:
             log_destinations=log_destinations,
             callbacks=tuple(callbacks),

-            # Checkpointing hparams
-            checkpoint_filepath=hparams.checkpoint_filepath,
-            checkpoint_interval_unit=hparams.checkpoint_interval_unit,
-            checkpoint_folder=hparams.checkpoint_folder,
-            checkpoint_interval=hparams.checkpoint_interval,
+            # Checkpoint loading hparams
+            checkpoint_filepath=checkpoint_filepath,
+            checkpoint_load_weights_only=checkpoint_load_weights_only,
+            checkpoint_strict_model_weights=checkpoint_strict_model_weights,
+
+            # Checkpoint saving hparams
+            checkpoint_interval_unit=checkpoint_interval_unit,
+            checkpoint_interval=checkpoint_interval,
+            checkpoint_folder=checkpoint_folder,

             # Subset parameters
             train_subset_num_batches=hparams.train_subset_num_batches,
@@ -674,11 +701,12 @@ def _ddp_reduce_tensor_sum(tensor: Tensor) -> Tensor:
                         self.eval(is_batch=True)

                     state.step += 1
-                    if self.checkpointer and self.checkpointer.should_checkpoint(state=state, event=Event.BATCH_END):
-                        self.checkpointer.save_checkpoint(state=state,
-                                                          seed=self.seed,
-                                                          device=self.device,
-                                                          config=self.config)
+                    if self.checkpoint_saver and self.checkpoint_saver.should_checkpoint(state=state,
+                                                                                         event=Event.BATCH_END):
+                        self.checkpoint_saver.save_checkpoint(state=state,
+                                                              seed=self.seed,
+                                                              device=self.device,
+                                                              config=self.config)
         except BreakEpochException:
             log.info(f'Skipping the rest of Epoch {state.epoch}')

@@ -692,8 +720,11 @@ def _ddp_reduce_tensor_sum(tensor: Tensor) -> Tensor:

             state.epoch += 1

-            if self.checkpointer and self.checkpointer.should_checkpoint(state=state, event=Event.EPOCH_END):
-                self.checkpointer.save_checkpoint(state=state, seed=self.seed, device=self.device, config=self.config)
+            if self.checkpoint_saver and self.checkpoint_saver.should_checkpoint(state=state, event=Event.EPOCH_END):
+                self.checkpoint_saver.save_checkpoint(state=state,
+                                                      seed=self.seed,
+                                                      device=self.device,
+                                                      config=self.config)

         self.engine.run_event(Event.TRAINING_END)
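
Finally, a hedged sketch of how a fine-tuning run might pass the new keyword arguments into the Trainer constructor. Only the checkpoint_* keywords come from this commit; the model, dataloaders, max_epochs, and checkpoint path are placeholders for whatever the run already defines:

from composer.trainer import Trainer

trainer = Trainer(
    model=model,                          # placeholder ComposerModel
    train_dataloader=train_dataloader,    # placeholder DataLoader
    eval_dataloader=eval_dataloader,      # placeholder DataLoader
    max_epochs=3,                         # placeholder

    # Checkpoint loading: start from pretrained weights only, requiring an exact key match.
    checkpoint_filepath="pretrain/checkpoints/ep10.pt",  # placeholder path
    checkpoint_load_weights_only=True,
    checkpoint_strict_model_weights=True,

    # Checkpoint saving: write a checkpoint every epoch under <run_directory>/checkpoints.
    checkpoint_interval_unit="ep",
    checkpoint_interval=1,
    checkpoint_folder="checkpoints",
)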
