import logging
import os
import random
+ import shutil
+ import tarfile
+ import tempfile
import warnings
from typing import Any, Dict, Optional
log = logging.getLogger(__name__)


+ def get_mosaic_checkpoint_filepath(checkpoint_folder: str, checkpoint_tag: str):
+     return os.path.join(checkpoint_folder, checkpoint_tag, "mosaic_states.pt")
+
+
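As a quick illustration of the layout this helper encodes, a hypothetical call (the folder name "checkpoints" and tag "ep1" are made up; POSIX path separators assumed) resolves like this:

    >>> get_mosaic_checkpoint_filepath("checkpoints", "ep1")
    'checkpoints/ep1/mosaic_states.pt'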
class CheckpointLoader:
    """Manager for initializing state and restoring RNG state from existing checkpoints.

@@ -28,7 +35,7 @@ class CheckpointLoader:
    """

    def __init__(self, checkpoint_filepath: str, load_weights_only: bool = False, strict_model_weights: bool = False):
-         self.state_dict = torch.load(checkpoint_filepath, map_location='cpu')
+         self.checkpoint_filepath = checkpoint_filepath
        self.load_weights_only = load_weights_only
        self.strict_model_weights = strict_model_weights
        self.checkpoint_rng_state = None
@@ -42,25 +49,45 @@ def load_checkpoint(self, state: State):
        Returns:
            The seed that was loaded from the checkpoint if it exists otherwise `None`.
        """
+         seed_to_restore = None

-         if self.load_weights_only:
-             state.load_model_state(self.state_dict['state'], strict=self.strict_model_weights)
-         else:
-             state.load_state_dict(self.state_dict["state"])
-             self.checkpoint_rng_state = self._get_checkpoint_rng_state(state, self.state_dict["rng"])
-
-             if "seed" in self.state_dict:
-                 world_size = ddp.get_world_size()
-                 checkpointed_world_size = len(self.state_dict["seed"])
-                 if world_size != checkpointed_world_size:
-                     warnings.warn(f"Current world size {world_size} does not match the checkpointed world size "
-                                   f"{checkpointed_world_size}. The seed will not be restored.")
-                     return
-                 seed_to_restore = self.state_dict["seed"][ddp.get_global_rank()]
-                 seed_all(seed_to_restore)
-                 return seed_to_restore
-
-     def restore_checkpoint_rng_state(self, state: State, device: Device):
+         with tempfile.TemporaryDirectory() as checkpoint_folder:
+             with tarfile.open(self.checkpoint_filepath) as tarball:
+                 tarball.extractall(checkpoint_folder)
+
+             checkpoint_tag = os.listdir(checkpoint_folder)[0]
+             mosaic_checkpoint_filepath = get_mosaic_checkpoint_filepath(checkpoint_folder, checkpoint_tag)
+
+             state_dict = torch.load(mosaic_checkpoint_filepath, map_location='cpu')
+
+             if self.load_weights_only:
+                 state.load_model_state(state_dict['state'], strict=self.strict_model_weights)
+             else:
+                 state.load_state_dict(state_dict["state"])
+                 self.checkpoint_rng_state = self._get_checkpoint_rng_state(state_dict["rng"])
+
+                 if "seed" in state_dict:
+                     world_size = ddp.get_world_size()
+                     checkpointed_world_size = len(state_dict["seed"])
+                     if world_size != checkpointed_world_size:
+                         warnings.warn(f"Current world size {world_size} does not match the checkpointed world size "
+                                       f"{checkpointed_world_size}. The seed will not be restored.")
+                     else:
+                         seed_to_restore = state_dict["seed"][ddp.get_global_rank()]
+                         seed_all(seed_to_restore)
+
+             try:
+                 import deepspeed
+                 if isinstance(state.model, deepspeed.DeepSpeedEngine):
+                     load_path, _ = state.model.load_checkpoint(checkpoint_folder, checkpoint_tag)  # type: ignore
+                     if load_path is None:
+                         raise RuntimeError(f"Failed to load DeepSpeed checkpoint from {self.checkpoint_filepath}")
+             except ImportError:
+                 pass
+
+         return seed_to_restore
+
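To make the new resume path concrete, here is a minimal sketch of how a caller might drive it, assuming a `state` and `device` already built by the surrounding Trainer and a hypothetical archive path:

    # Hypothetical usage sketch; the Trainer normally performs this wiring.
    loader = CheckpointLoader("checkpoints/ep1.tgz")
    restored_seed = loader.load_checkpoint(state)   # extracts the tarball, loads mosaic_states.pt, restores the seed
    loader.restore_checkpoint_rng_state(device)     # later, restores the per-rank RNG state captured at save time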
+     def restore_checkpoint_rng_state(self, device: Device):
        """Restore the state of all RNG objects in this context from the loaded checkpoint's data.
        """

@@ -79,7 +106,7 @@ def restore_checkpoint_rng_state(self, state: State, device: Device):

        self.checkpoint_rng_state = None

-     def _get_checkpoint_rng_state(self, state: State, checkpoint_rng_state: StateDict) -> Optional[StateDict]:
+     def _get_checkpoint_rng_state(self, checkpoint_rng_state: StateDict) -> Optional[StateDict]:
        original_world_size = len(checkpoint_rng_state["torch"])
        if original_world_size == ddp.get_world_size():
            return checkpoint_rng_state
@@ -139,39 +166,60 @@ def save_checkpoint(self, state: State, seed: int, device: Device, config: Optio
            'rng': self._get_rng_state(device=device),  # stored across all ranks
            'seed': ddp.all_gather_object(seed),
        }
-         if ddp.get_global_rank() != 0:
-             # only rank 0 saves checkpoints
-             # Need the check down here so all the DDP syncs will work for generating the checkpoint
-             return

-         # we add the state only on rank 0 since other processes don't have loggers to serialize
-         state_dict['state'] = state.state_dict()  # should be the same across all ranks. per-rank state not stored
-
-         if config:
-             hparams_path = os.path.join(self.checkpoint_folder, "hparams.yaml")
-             os.makedirs(self.checkpoint_folder, mode=0o775, exist_ok=True)
-             config_yaml_str = yaml.dump(config)
-             try:
-                 with open(hparams_path, "x") as f:
-                     # Storing the config (ex. hparams) in a separate file so they can be modified before resuming
-                     f.write(config_yaml_str)
-             except FileExistsError as e:
-                 with open(hparams_path, "r") as f:
-                     # comparing the parsed hparams to ignore whitespace and formatting differences
-                     if yaml.safe_load(config_yaml_str) != yaml.safe_load(f):
-                         raise RuntimeError(f"The hparams in the existing checkpoint folder {self.checkpoint_folder} "
-                                            "differ from those being used in the current training run. "
-                                            "Please specify a new checkpoint folder.") from e
        if self.save_event == Event.EPOCH_END:
-             filename = f"ep{state.epoch}.pt"
+             tag = f"ep{state.epoch}"
        elif self.save_event == Event.BATCH_END:
-             filename = f"it{state.step}.pt"
+             tag = f"it{state.step}"
        else:
            raise ValueError(f"Invalid checkpoint event: {self.save_event}")
-         save_file = os.path.join(self.checkpoint_folder, filename)
-         with open(save_file, 'xb') as f:
-             torch.save(state_dict, f)
-         log.info(f'Trainer checkpoint saved to {save_file}')
+
+         try:
+             import deepspeed
+             if isinstance(state.model, deepspeed.DeepSpeedEngine):
+                 state.model.save_checkpoint(self.checkpoint_folder, tag)  # type: ignore
+         except ImportError:
+             pass
+
+         if ddp.get_global_rank() == 0:
+             # only rank 0 saves checkpoints
+
+             # we add the state only on rank 0 since other processes don't have loggers to serialize
+             state_dict['state'] = state.state_dict()  # should be the same across all ranks. per-rank state not stored
+
+             if config:
+                 hparams_path = os.path.join(self.checkpoint_folder, "hparams.yaml")
+                 os.makedirs(self.checkpoint_folder, mode=0o775, exist_ok=True)
+                 config_yaml_str = yaml.dump(config)
+                 try:
+                     with open(hparams_path, "x") as f:
+                         # Storing the config (ex. hparams) in a separate file so they can be modified before resuming
+                         f.write(config_yaml_str)
+                 except FileExistsError as e:
+                     with open(hparams_path, "r") as f:
+                         # comparing the parsed hparams to ignore whitespace and formatting differences
+                         if yaml.safe_load(config_yaml_str) != yaml.safe_load(f):
+                             raise RuntimeError(
+                                 f"The hparams in the existing checkpoint folder {self.checkpoint_folder} "
+                                 "differ from those being used in the current training run. "
+                                 "Please specify a new checkpoint folder.") from e
+             checkpoint_filepath = os.path.join(self.checkpoint_folder, tag)
+             mosaic_states_filepath = get_mosaic_checkpoint_filepath(self.checkpoint_folder, tag)
+             if not os.path.exists(checkpoint_filepath):
+                 os.makedirs(checkpoint_filepath)
+             with open(mosaic_states_filepath, 'xb') as f:
+                 torch.save(state_dict, f)
+
+             checkpoint_archive_filepath = os.path.join(self.checkpoint_folder, f'{tag}.tgz')
+             with tarfile.open(checkpoint_archive_filepath, "w:gz") as tarball:
+                 tarball.add(checkpoint_filepath, arcname=tag)
+
+             shutil.rmtree(checkpoint_filepath)
+
+             log.info(f'Trainer checkpoint saved to {checkpoint_archive_filepath}')
+
+         # Ensure that the non-rank 0 processes don't exit before the checkpoint is saved.
+         ddp.barrier()
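For orientation, the save path above should leave a layout like the following on rank 0, assuming checkpoint_folder="checkpoints" and an epoch-end save after epoch 1 (names are illustrative):

    checkpoints/
        hparams.yaml   # written only when a config is passed
        ep1.tgz        # archives ep1/mosaic_states.pt, plus any DeepSpeed shard files written into the same tag folder

The staging directory checkpoints/ep1/ is tarred and then deleted, so only the compressed archive (and hparams.yaml) remains.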

    def _get_rng_state(self, device: Device) -> StateDict:
        rng_state = {