
Commit 995fafe

Checkpointing & DeepSpeed (#199)
* preliminary test for deepspeed checkpoints
* deepspeed tests working
* we'll determine later whether that actually does anything
* checkpointer understands deepspeed
* the tests mostly pass
* tests mostly passing
* consolidated test
* cleanup
* formatting
* pyright and formatting
* isort
* change port to be consistent with test script
* address comments
* fix pyright
1 parent 2907bf0 commit 995fafe

File tree: 7 files changed, +245 −124 lines changed


composer/core/state.py

File mode changed: 100644 → 100755
Lines changed: 26 additions & 0 deletions
@@ -42,6 +42,13 @@
     "scaler",
 ]
 
+# These fields will be serialized using .state_dict(), but will be skipped if DeepSpeed is enabled.
+# When DeepSpeed is being used, model and optimizer states are serialized directly by the DeepSpeed engine.
+STATE_DICT_SERIALIZATION_FIELDS_SKIP_DEEPSPEED = [
+    "model",
+    "_optimizers",
+]
+
 # These fields will not be serialized
 SKIP_SERIALIZATION_FIELDS = [
     "loss", "batch", "outputs", "train_dataloader", "eval_dataloader", "_steps_per_epoch", "_precision_context"
@@ -191,13 +198,22 @@ def state_dict(self) -> types.StateDict:
         """Returns the state as a :class:`dict`."""
         state_dict: types.StateDict = {}
 
+        deepspeed_enabled = False
+        try:
+            import deepspeed
+            deepspeed_enabled = isinstance(self.model, deepspeed.DeepSpeedEngine)
+        except ImportError:
+            pass
+
         for state_field_name, state_field_value in self.__dict__.items():
             if state_field_name in SKIP_SERIALIZATION_FIELDS:
                 continue
             elif state_field_name in DIRECT_SERIALIZATION_FIELDS:
                 state_dict[state_field_name] = state_field_value
                 continue
             elif state_field_name in STATE_DICT_SERIALIZATION_FIELDS:
+                if deepspeed_enabled and state_field_name in STATE_DICT_SERIALIZATION_FIELDS_SKIP_DEEPSPEED:
+                    continue
                 if state_field_name == "model":
                     # Save model directly instead of by class name, since model may be wrapped by DistributedDataParallel
                     serialized_value = state_field_value.state_dict()
@@ -208,9 +224,12 @@ def state_dict(self) -> types.StateDict:
                         if obj is not None
                     }
                 state_dict[state_field_name] = serialized_value
+
             else:
                 raise RuntimeError(f"Unable to serialize field {state_field_name}")
         state_dict["_is_model_ddp_wrapped"] = isinstance(self.model, DistributedDataParallel)
+        if deepspeed_enabled:
+            state_dict["_deepspeed_enabled"] = True
         return state_dict
 
     def load_model_state(self, state_dict: types.StateDict, strict: bool):
@@ -237,12 +256,19 @@ def load_state_dict(self, state: types.StateDict, strict: bool = False):
             state_dict (types.StateDict): object returned from call to :meth:`state_dict`.
 
         """
+
+        deepspeed_enabled = False
+        if "_deepspeed_enabled" in state:
+            deepspeed_enabled = state["_deepspeed_enabled"]
+
         for state_field_name, state_field_value in self.__dict__.items():
             if state_field_name in SKIP_SERIALIZATION_FIELDS:
                 continue
             elif state_field_name in DIRECT_SERIALIZATION_FIELDS:
                 setattr(self, state_field_name, state[state_field_name])
             elif state_field_name in STATE_DICT_SERIALIZATION_FIELDS:
+                if deepspeed_enabled and state_field_name in STATE_DICT_SERIALIZATION_FIELDS_SKIP_DEEPSPEED:
+                    continue
                 serialized_value = state[state_field_name]
 
                 if state_field_name == "model":
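Note (not part of the diff): the pattern behind the state.py changes above is an optional import plus an isinstance check. When the model is a deepspeed.DeepSpeedEngine, Composer's state_dict skips the "model" and "_optimizers" fields and records only a "_deepspeed_enabled" flag, leaving the heavy tensors to the engine's own checkpoint files. A minimal standalone sketch of that detection pattern (the helper name is_deepspeed_model is hypothetical, not from this commit):

    from typing import Any

    def is_deepspeed_model(model: Any) -> bool:
        # Detect a DeepSpeed-wrapped model without requiring deepspeed to be installed.
        try:
            import deepspeed  # optional dependency
        except ImportError:
            return False
        return isinstance(model, deepspeed.DeepSpeedEngine)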

composer/trainer/checkpoint.py

File mode changed: 100644 → 100755
Lines changed: 96 additions & 48 deletions
@@ -3,6 +3,9 @@
 import logging
 import os
 import random
+import shutil
+import tarfile
+import tempfile
 import warnings
 from typing import Any, Dict, Optional
 
@@ -18,6 +21,10 @@
 log = logging.getLogger(__name__)
 
 
+def get_mosaic_checkpoint_filepath(checkpoint_folder: str, checkpoint_tag: str):
+    return os.path.join(checkpoint_folder, checkpoint_tag, "mosaic_states.pt")
+
+
 class CheckpointLoader:
     """Manager for initializing state and restoring RNG state from existing checkpoints.
 
2835
"""
2936

3037
def __init__(self, checkpoint_filepath: str, load_weights_only: bool = False, strict_model_weights: bool = False):
31-
self.state_dict = torch.load(checkpoint_filepath, map_location='cpu')
38+
self.checkpoint_filepath = checkpoint_filepath
3239
self.load_weights_only = load_weights_only
3340
self.strict_model_weights = strict_model_weights
3441
self.checkpoint_rng_state = None
@@ -42,25 +49,45 @@ def load_checkpoint(self, state: State):
         Returns:
             The seed that was loaded from the checkpoint if it exists otherwise `None`.
         """
+        seed_to_restore = None
 
-        if self.load_weights_only:
-            state.load_model_state(self.state_dict['state'], strict=self.strict_model_weights)
-        else:
-            state.load_state_dict(self.state_dict["state"])
-            self.checkpoint_rng_state = self._get_checkpoint_rng_state(state, self.state_dict["rng"])
-
-            if "seed" in self.state_dict:
-                world_size = ddp.get_world_size()
-                checkpointed_world_size = len(self.state_dict["seed"])
-                if world_size != checkpointed_world_size:
-                    warnings.warn(f"Current world size {world_size} does not match the checkpointed world size "
-                                  f"{checkpointed_world_size}. The seed will not be restored.")
-                    return
-                seed_to_restore = self.state_dict["seed"][ddp.get_global_rank()]
-                seed_all(seed_to_restore)
-                return seed_to_restore
-
-    def restore_checkpoint_rng_state(self, state: State, device: Device):
+        with tempfile.TemporaryDirectory() as checkpoint_folder:
+            with tarfile.open(self.checkpoint_filepath) as tarball:
+                tarball.extractall(checkpoint_folder)
+
+            checkpoint_tag = os.listdir(checkpoint_folder)[0]
+            mosaic_checkpoint_filepath = get_mosaic_checkpoint_filepath(checkpoint_folder, checkpoint_tag)
+
+            state_dict = torch.load(mosaic_checkpoint_filepath, map_location='cpu')
+
+            if self.load_weights_only:
+                state.load_model_state(state_dict['state'], strict=self.strict_model_weights)
+            else:
+                state.load_state_dict(state_dict["state"])
+                self.checkpoint_rng_state = self._get_checkpoint_rng_state(state_dict["rng"])
+
+                if "seed" in state_dict:
+                    world_size = ddp.get_world_size()
+                    checkpointed_world_size = len(state_dict["seed"])
+                    if world_size != checkpointed_world_size:
+                        warnings.warn(f"Current world size {world_size} does not match the checkpointed world size "
+                                      f"{checkpointed_world_size}. The seed will not be restored.")
+                    else:
+                        seed_to_restore = state_dict["seed"][ddp.get_global_rank()]
+                        seed_all(seed_to_restore)
+
+            try:
+                import deepspeed
+                if isinstance(state.model, deepspeed.DeepSpeedEngine):
+                    load_path, _ = state.model.load_checkpoint(checkpoint_folder, checkpoint_tag)  # type: ignore
+                    if load_path is None:
+                        raise RuntimeError(f"Failed to load DeepSpeed checkpoint from {self.checkpoint_filepath}")
+            except ImportError:
+                pass
+
+        return seed_to_restore
+
+    def restore_checkpoint_rng_state(self, device: Device):
         """Restore the state of all RNG objects in this context from the loaded checkpoint's data.
         """
 
@@ -79,7 +106,7 @@ def restore_checkpoint_rng_state(self, state: State, device: Device):
 
         self.checkpoint_rng_state = None
 
-    def _get_checkpoint_rng_state(self, state: State, checkpoint_rng_state: StateDict) -> Optional[StateDict]:
+    def _get_checkpoint_rng_state(self, checkpoint_rng_state: StateDict) -> Optional[StateDict]:
         original_world_size = len(checkpoint_rng_state["torch"])
         if original_world_size == ddp.get_world_size():
             return checkpoint_rng_state
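Note (not part of the diff): since a checkpoint is now a .tgz archive containing a single tag directory rather than a bare .pt file, it can be inspected outside the trainer by reversing the steps load_checkpoint performs above. A rough sketch, assuming a checkpoint was saved to checkpoints/ep1.tgz (the path is illustrative):

    import os
    import tarfile
    import tempfile

    import torch

    archive = "checkpoints/ep1.tgz"  # illustrative path, not from the commit

    with tempfile.TemporaryDirectory() as folder:
        with tarfile.open(archive) as tarball:
            tarball.extractall(folder)  # yields a single top-level tag directory, e.g. "ep1/"
        tag = os.listdir(folder)[0]

        # Composer's own state; DeepSpeed's engine files (when enabled) sit in the same tag directory.
        mosaic_states = torch.load(os.path.join(folder, tag, "mosaic_states.pt"), map_location="cpu")
        print(sorted(mosaic_states.keys()))  # expected keys include 'rng', 'seed', and 'state'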
@@ -139,39 +166,60 @@ def save_checkpoint(self, state: State, seed: int, device: Device, config: Optio
             'rng': self._get_rng_state(device=device),  # stored across all ranks
             'seed': ddp.all_gather_object(seed),
         }
-        if ddp.get_global_rank() != 0:
-            # only rank 0 saves checkpoints
-            # Need the check down here so all the DDP syncs will work for generating the checkpoint
-            return
 
-        # we add the state only on rank 0 since other processes don't have loggers to serialize
-        state_dict['state'] = state.state_dict()  # should be the same across all ranks. per-rank state not stored
-
-        if config:
-            hparams_path = os.path.join(self.checkpoint_folder, "hparams.yaml")
-            os.makedirs(self.checkpoint_folder, mode=0o775, exist_ok=True)
-            config_yaml_str = yaml.dump(config)
-            try:
-                with open(hparams_path, "x") as f:
-                    # Storing the config (ex. hparams) in a separate file so they can be modified before resuming
-                    f.write(config_yaml_str)
-            except FileExistsError as e:
-                with open(hparams_path, "r") as f:
-                    # comparing the parsed hparams to ignore whitespace and formatting differences
-                    if yaml.safe_load(config_yaml_str) != yaml.safe_load(f):
-                        raise RuntimeError(f"The hparams in the existing checkpoint folder {self.checkpoint_folder} "
-                                           "differ from those being used in the current training run. "
-                                           "Please specify a new checkpoint folder.") from e
         if self.save_event == Event.EPOCH_END:
-            filename = f"ep{state.epoch}.pt"
+            tag = f"ep{state.epoch}"
         elif self.save_event == Event.BATCH_END:
-            filename = f"it{state.step}.pt"
+            tag = f"it{state.step}"
         else:
             raise ValueError(f"Invalid checkpoint event: {self.save_event}")
-        save_file = os.path.join(self.checkpoint_folder, filename)
-        with open(save_file, 'xb') as f:
-            torch.save(state_dict, f)
-        log.info(f'Trainer checkpoint saved to {save_file}')
+
+        try:
+            import deepspeed
+            if isinstance(state.model, deepspeed.DeepSpeedEngine):
+                state.model.save_checkpoint(self.checkpoint_folder, tag)  # type: ignore
+        except ImportError:
+            pass
+
+        if ddp.get_global_rank() == 0:
+            # only rank 0 saves checkpoints
+
+            # we add the state only on rank 0 since other processes don't have loggers to serialize
+            state_dict['state'] = state.state_dict()  # should be the same across all ranks. per-rank state not stored
+
+            if config:
+                hparams_path = os.path.join(self.checkpoint_folder, "hparams.yaml")
+                os.makedirs(self.checkpoint_folder, mode=0o775, exist_ok=True)
+                config_yaml_str = yaml.dump(config)
+                try:
+                    with open(hparams_path, "x") as f:
+                        # Storing the config (ex. hparams) in a separate file so they can be modified before resuming
+                        f.write(config_yaml_str)
+                except FileExistsError as e:
+                    with open(hparams_path, "r") as f:
+                        # comparing the parsed hparams to ignore whitespace and formatting differences
+                        if yaml.safe_load(config_yaml_str) != yaml.safe_load(f):
+                            raise RuntimeError(
+                                f"The hparams in the existing checkpoint folder {self.checkpoint_folder} "
+                                "differ from those being used in the current training run. "
+                                "Please specify a new checkpoint folder.") from e
+            checkpoint_filepath = os.path.join(self.checkpoint_folder, tag)
+            mosaic_states_filepath = get_mosaic_checkpoint_filepath(self.checkpoint_folder, tag)
+            if not os.path.exists(checkpoint_filepath):
+                os.makedirs(checkpoint_filepath)
+            with open(mosaic_states_filepath, 'xb') as f:
+                torch.save(state_dict, f)
+
+            checkpoint_archive_filepath = os.path.join(self.checkpoint_folder, f'{tag}.tgz')
+            with tarfile.open(checkpoint_archive_filepath, "w:gz") as tarball:
+                tarball.add(checkpoint_filepath, arcname=tag)
+
+            shutil.rmtree(checkpoint_filepath)
+
+            log.info(f'Trainer checkpoint saved to {checkpoint_archive_filepath}')
+
+        # Ensure that the non-rank 0 processes don't exit before the checkpoint is saved.
+        ddp.barrier()
 
     def _get_rng_state(self, device: Device) -> StateDict:
         rng_state = {
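Note (not part of the diff): the save path mirrors the loader. Every rank lets the DeepSpeed engine write its shards into <checkpoint_folder>/<tag>/, rank 0 adds mosaic_states.pt, then rank 0 rolls the tag directory into <tag>.tgz and deletes the directory, and ddp.barrier() keeps the other ranks alive until the archive exists. The directory-to-archive step in isolation looks roughly like this (a sketch; the helper name archive_checkpoint_dir is invented):

    import os
    import shutil
    import tarfile

    def archive_checkpoint_dir(checkpoint_folder: str, tag: str) -> str:
        # Pack <checkpoint_folder>/<tag>/ into <checkpoint_folder>/<tag>.tgz and remove the directory.
        checkpoint_dirpath = os.path.join(checkpoint_folder, tag)
        archive_filepath = os.path.join(checkpoint_folder, f"{tag}.tgz")
        with tarfile.open(archive_filepath, "w:gz") as tarball:
            # arcname keeps the tag as the single top-level entry inside the archive
            tarball.add(checkpoint_dirpath, arcname=tag)
        shutil.rmtree(checkpoint_dirpath)
        return archive_filepath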

composer/trainer/trainer.py

Lines changed: 34 additions & 40 deletions
@@ -93,7 +93,7 @@ class Trainer:
         log_destinations (List[BaseLoggerBackend], optional): The destinations to log training information to.
             (default ``[TQDMLoggerBackend()]``).
         callbacks (Sequence[Callback], optional): The callbacks to run during training. (default: ``[]``)
-        checkpoint_filepath (str): For loading checkpoints, the path to an existing checkpoint file.
+        checkpoint_filepath (str): For loading checkpoints, the path to an existing checkpoint.
         load_weights_only (bool): Whether to only restore the weights from the checkpoint without
             restoring the associated state.
         strict_model_weights (bool, optional): Whether to force that the checkpointed weights must exactly
@@ -300,33 +300,49 @@ def __init__(
         self.state.optimizers = optimizer
         self.state.schedulers = ComposedScheduler(schedulers=schedulers)
 
-        # TODO(#121): get checkpointing working with DeepSpeed.
+        assert isinstance(self.state.model, BaseMosaicModel)
+        self.original_model = self.state.model  # type: ignore # TODO(ravi) -- update the state to add an original model helper
+
         self.checkpoint_saver = None
-        if checkpoint_interval is not None and checkpoint_interval_unit is not None:
-            self.checkpoint_saver = CheckpointSaver(checkpoint_interval_unit=checkpoint_interval_unit,
+        if checkpoint_folder and checkpoint_interval and checkpoint_interval_unit:
+            self.checkpoint_saver = CheckpointSaver(checkpoint_folder=get_relative_to_run_directory(checkpoint_folder),
                                                     checkpoint_interval=checkpoint_interval,
-                                                    checkpoint_folder=get_relative_to_run_directory(checkpoint_folder))
-
-            if self.deepspeed_enabled:
-                raise NotImplementedError("Checkpointing is not yet supported with DeepSpeed.")
+                                                    checkpoint_interval_unit=checkpoint_interval_unit)
 
-        # TODO(#121): get checkpointing working with DeepSpeed.
         self.checkpoint_loader = None
-        if checkpoint_filepath is not None:
-            if self.deepspeed_enabled:
-                raise NotImplementedError("Checkpointing is not yet supported with DeepSpeed.")
-
+        if checkpoint_filepath:
             self.checkpoint_loader = CheckpointLoader(checkpoint_filepath=checkpoint_filepath,
                                                       load_weights_only=checkpoint_load_weights_only,
                                                       strict_model_weights=checkpoint_strict_model_weights)
 
+        # place the state, model in the proper devices, and initialize from a checkpoint if provided
+        if self.deepspeed_enabled:
+            import deepspeed
+
+            assert self.deepspeed_hparams is not None
+            deepspeed_config = self.deepspeed_hparams.initialize_object(self.state, self.grad_clip_norm)
+            optimizer = ensure_tuple(self.state.optimizers)[0]
+            (self.state.model, self.state.optimizers, _, _) = deepspeed.initialize(
+                config=deepspeed_config,
+                model=self.state.model,
+                optimizer=optimizer,
+            )
+
+        # If using DeepSpeed, the model must be loaded from checkpoint after the engine has been
+        # initialized, but if using PyTorch DDP, the model must be loaded before it is wrapped with
+        # DDP.
+        if self.checkpoint_loader:
             restored_seed = self.checkpoint_loader.load_checkpoint(state=self.state)
-            # Set the restored seed so that the correct seed will be saved in future checkpoints
-            # Used to handle the case where another checkpoint is saved after resuming from checkpoint.
-            # In this case, self.seed is stored in the second checkpoint so it must have the correct value.
             if restored_seed is not None:
                 self.seed = restored_seed
 
+        if not self.deepspeed_enabled:
+            self.state.model = self.device.module_to_device(self.state.model)
+            self.state.optimizers = map_collection(self.state.optimizers, self.device.optimizer_to_device)
+
+            # wrap model with DDP
+            self.state.model = ddp.prepare_module(self.state.model, self.find_unused_parameters)
+
     @classmethod
     def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer:
         """Instantiate a Trainer using a `TrainerHparams` object.
@@ -551,28 +567,6 @@ def _train_loop(self) -> None:
             raise NotImplementedError("The Mosaic trainer only supports one optimizer; "
                                       f"found {len(ensure_tuple(state.optimizers))} optimizers")
 
-        assert isinstance(state.model, BaseMosaicModel)
-        self.original_model = state.model  # type: ignore # TODO(ravi) -- update the state to add an original model helper
-
-        # place the state, model in the proper devices
-        if self.deepspeed_enabled:
-            import deepspeed
-
-            assert self.deepspeed_hparams is not None
-            deepspeed_config = self.deepspeed_hparams.initialize_object(state, self.grad_clip_norm)
-            optimizer = ensure_tuple(state.optimizers)[0]
-            (state.model, state.optimizers, _, _) = deepspeed.initialize(
-                config=deepspeed_config,
-                model=state.model,
-                optimizer=optimizer,
-            )
-        else:
-            state.model = self.device.module_to_device(state.model)
-            state.optimizers = map_collection(state.optimizers, self.device.optimizer_to_device)
-
-        # wrap model with DDP
-        state.model = ddp.prepare_module(state.model, self.find_unused_parameters)
-
         # print training start
         self.logger.metric_fit({"trainer/algorithms": [str(algo) for algo in self.engine.algorithms]})
 
@@ -609,7 +603,7 @@ def _ddp_reduce_tensor_sum(tensor: Tensor) -> Tensor:
 
         if self.state.batch_idx == 0 and self.checkpoint_loader:
             # only restore the rng state here if the step in the current epoch is zero.
-            self.checkpoint_loader.restore_checkpoint_rng_state(self.state, self.device)
+            self.checkpoint_loader.restore_checkpoint_rng_state(self.device)
 
         for _ in range(state.epoch, state.max_epochs):
             try:
@@ -625,7 +619,7 @@ def _ddp_reduce_tensor_sum(tensor: Tensor) -> Tensor:
                 # if resuming, skip dataloader forward to the minibatch index
                 if batch_idx < self.state.batch_idx:
                     if self.checkpoint_loader:
-                        self.checkpoint_loader.restore_checkpoint_rng_state(self.state, self.device)
+                        self.checkpoint_loader.restore_checkpoint_rng_state(self.device)
                     continue
 
                 state.last_batch_size = self._get_batch_size(state.batch)
