@@ -141,7 +141,9 @@ def load_optimizer(
     ckp_tp_size = ckp_optimizer_config["parallelism"]["tp_size"]
     ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]

-    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()):
+    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
+        parallel_context.pp_pg.size()
+    ):
         assert (
             param_shard_metadata is not None
         ), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
@@ -179,14 +181,17 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -

         model_state_dict = model.state_dict()
         new_optim_state_dict = optimizer.state_dict()
+        # TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
         OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
         # NOTE: because we can only resume training with the same optimizer type
         # (0, 0) = (pp_rank, tp_rank)
         # NOTE: also we don't merge "step" because it's just a scalar
-
-        param_names = sorted(model_state_dict.items(), key=lambda x: x[0])
-        for param_name, _ in tqdm(
-            param_names,
+        param_names = list(model_state_dict.keys())
+        new_optim_state_param_names = {}
+        # NOTE: iterates through all model parameters in the local pipeline parallel rank (hence, might not be the full model).
+        # Since model parameters and optimizer states are aligned, loads only the optimizer states for these parameters from the checkpoint shards.
+        for param_index, param_name in tqdm(
+            enumerate(param_names),
             disable=dist.get_rank(parallel_context.world_pg) != 0,
             desc="Topology-agnostic optimizer loading",
         ):
@@ -198,28 +203,49 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
             if not isinstance(param, NanotronParameter):
                 raise NotImplementedError("Parameters are required to be NanotronParameter")

+            # NOTE: for tied parameters, the metadata is stored using the parameter name,
+            # while the data is stored using the name of the main tied parameter,
+            # which may be different (e.g. `model.token_position_embeddings.pp_block.token_embedding.weight`
+            # for `model.lm_head.pp_block.weight`).
+            base_name = param.get_tied_info().name if param.is_tied else param_name
+            if param_name != base_name:
+                # NOTE: skip tied parameter if main tied parameter has already been loaded
+                # (not always the case if pipeline parallel)
+                if base_name in new_optim_state_param_names.values():
+                    continue
+            new_optim_state_param_names[param_index] = base_name
+
             if param.is_sharded:
                 # NOTE: an optimizer state's shape is equal to the parameter's shape
                 # NOTE: sometimes an unsharded parameter's shape differs
                 # from an unsharded optimizer state's shape
                 new_shard_metadata = param.get_sharded_info()
                 new_unshared_shape = new_shard_metadata.unsharded_shape
-
-                # NOTE: merging optimizer states
-                optim_state_index = find_optim_index_from_param_name(
-                    param_name, ckp_sharded_optim_states, is_zero1=False
-                )
-
-                new_optim_state_dict["state"][optim_state_index] = {}
+                new_optim_state_dict["state"][param_index] = {}
+                # NOTE: restore each state tensor (e.g. exp_avg) by iterating through
+                # the optimizer state shards saved using the previous topology
                 for state_key in OPTIMIZER_STATE_NAMES:
                     # TODO(xrsrke): free the memory of the shards that don't
                     # correspond to the current rank
                     buffer = torch.zeros_like(param, device="cuda")
                     unsharded_buffer = torch.empty(new_unshared_shape, device="cuda")

                     for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
-                        ckp_shard_metadata = get_checkpoint_state_metadata(param_name, pp_rank, tp_rank)
-                        ckp_shard_data = ckp_optim_state["state"][optim_state_index][state_key]
+                        old_optim_state_index = find_optim_index_from_param_name(
+                            base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
+                        )
+                        if old_optim_state_index is None:
+                            continue  # NOTE: param is not in this pp shard
+                        ckp_shard_data = ckp_optim_state["state"][old_optim_state_index][state_key]
+                        # NOTE: the metadata for the main parameter of a tied parameter might be in a
+                        # different pipeline parallel shard.
+                        if param.is_tied:
+                            metadata_pp_rank = next(
+                                iter(param_shard_metadata[param_name.replace("module.", "")].keys())
+                            )[0]
+                        else:
+                            metadata_pp_rank = pp_rank
+                        ckp_shard_metadata = get_checkpoint_state_metadata(param_name, metadata_pp_rank, tp_rank)

                         # NOTE: if the checkpoint is from a Zero-1 optimizer,
                         # it's flattened, so we need to reshape it
@@ -229,7 +255,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
                             orig_shape = [int(dim) for dim in orig_shape]
                             ckp_shard_data = ckp_shard_data.view(orig_shape)

-                        new_optim_state_dict["state"][optim_state_index][state_key] = merge_and_shard_tp_tensors(
+                        new_optim_state_dict["state"][param_index][state_key] = merge_and_shard_tp_tensors(
                             buffer,
                             unsharded_buffer,
                             [
@@ -240,17 +266,16 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -

                         if ckp_optim_type == ZeroDistributedOptimizer.__name__:
                             # NOTE: flatten the optimizer states
-                            new_optim_state_dict["state"][optim_state_index][state_key] = new_optim_state_dict[
-                                "state"
-                            ][optim_state_index][state_key].flatten()
-
-                        new_optim_state_dict["state"][optim_state_index]["step"] = ckp_optim_state["state"][optim_state_index][
-                            "step"
-                        ]
-
-        # NOTE: since all shards have the same optim state names
-        # so we take the first shard
-        new_optim_state_dict["names"] = ckp_sharded_optim_states[(0, 0)]["names"]
+                            new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][
+                                param_index
+                            ][state_key].flatten()
+                        # NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
+                        # try to get the step value as well.
+                        step = ckp_optim_state["state"][old_optim_state_index].get("step")
+                        if step is not None:
+                            new_optim_state_dict["state"][param_index]["step"] = step
+
+        new_optim_state_dict["names"] = new_optim_state_param_names
         state_dict = new_optim_state_dict
     else:
         # TODO @thomasw21: Load optimizer type and check that it's compatible, otherwise we might be loading something else completely
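
For reference, below is a minimal, self-contained sketch of the merge-then-reshard idea that the merge_and_shard_tp_tensors call above relies on: each optimizer-state shard saved under the old tensor parallel topology is written at its recorded offset into an unsharded buffer, and the slice owned by the current rank is then cut back out. This is not Nanotron's implementation; the helper name merge_and_reshard and the (shard, slice) metadata format are illustrative assumptions.

import torch

def merge_and_reshard(ckp_shards, unsharded_shape, new_slice):
    # ckp_shards: list of (shard_tensor, slice_in_unsharded) pairs from the old topology.
    # new_slice: the slice of the unsharded tensor owned by the current TP rank.
    unsharded = torch.empty(unsharded_shape)
    for shard, sl in ckp_shards:
        unsharded[sl] = shard  # place each old shard at its recorded offset
    return unsharded[new_slice].clone()  # re-shard for the new topology

# Example: a (4, 6) state tensor saved with tp_size=2 (split on dim 0), resumed with tp_size=4.
full = torch.arange(24, dtype=torch.float32).reshape(4, 6)
old_shards = [(full[0:2], (slice(0, 2),)), (full[2:4], (slice(2, 4),))]
new_rank0_shard = merge_and_reshard(old_shards, (4, 6), (slice(0, 1),))
assert torch.equal(new_rank0_shard, full[0:1])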