@@ -144,7 +144,9 @@ def load_optimizer(
     ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]
     ckpt_expert_parallel_size = ckp_optimizer_config["parallelism"]["expert_parallel_size"]
 
-    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()):
+    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
+        parallel_context.pp_pg.size()
+    ):
         assert (
             param_shard_metadata is not None
         ), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
@@ -182,14 +184,17 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
 
         model_state_dict = model.state_dict()
         new_optim_state_dict = optimizer.state_dict()
+        # TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
         OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
         # NOTE: because we can only resume training with the same optimizer type
         # (0, 0) = (pp_rank, tp_rank)
         # NOTE: also we don't merge "step" because it's just a scalar
-
-        param_names = sorted(model_state_dict.items(), key=lambda x: x[0])
-        for param_name, _ in tqdm(
-            param_names,
+        param_names = list(model_state_dict.keys())
+        new_optim_state_param_names = {}
+        # NOTE: iterates through all model parameters in the local pipeline parallel rank (hence, might not be the full model).
+        # Since model parameters and optimizer states are aligned, loads only the optimizer states for these parameters from the checkpoint shards.
+        for param_index, param_name in tqdm(
+            enumerate(param_names),
             disable=dist.get_rank(parallel_context.world_pg) != 0,
             desc="Topology-agnostic optimizer loading",
         ):
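For context on why `param_index` from `enumerate` can serve directly as the key under which the restored state is written: a PyTorch optimizer's `state_dict()` keys its `"state"` entries by parameter position, not by name, which is also why the index-to-name mapping has to be tracked separately in `new_optim_state_param_names`. A self-contained illustration with a toy model (not nanotron code):

```python
# Toy example: PyTorch optimizer state is indexed by parameter position.
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()
optimizer.step()

state = optimizer.state_dict()["state"]
print(list(state.keys()))        # [0, 1] -> positional indices (weight, bias), no names
print(sorted(state[0].keys()))   # ['exp_avg', 'exp_avg_sq', 'step']
```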
@@ -201,28 +206,49 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
             if not isinstance(param, NanotronParameter):
                 raise NotImplementedError("Parameters are required to be NanotronParameter")
 
+            # NOTE: for tied parameters, the metadata is stored using the parameter name,
+            # while the data is stored using the name of the main tied parameter,
+            # which may be different (e.g. `model.token_position_embeddings.pp_block.token_embedding.weight`
+            # for `model.lm_head.pp_block.weight`).
+            base_name = param.get_tied_info().name if param.is_tied else param_name
+            if param_name != base_name:
+                # NOTE: skip the tied parameter if its main tied parameter has already been loaded
+                # (not always the case with pipeline parallelism)
+                if base_name in new_optim_state_param_names.values():
+                    continue
+            new_optim_state_param_names[param_index] = base_name
+
             if param.is_sharded:
                 # NOTE: an optimizer state's shape is equal to the parameter's shape
                 # NOTE: sometimes an unsharded parameter's shape differs
                 # from an unsharded optimizer state's shape
                 new_shard_metadata = param.get_sharded_info()
                 new_unshared_shape = new_shard_metadata.unsharded_shape
-
-                # NOTE: merging optimizer states
-                optim_state_index = find_optim_index_from_param_name(
-                    param_name, ckp_sharded_optim_states, is_zero1=False
-                )
-
-                new_optim_state_dict["state"][optim_state_index] = {}
+                new_optim_state_dict["state"][param_index] = {}
+                # NOTE: restore each state tensor (e.g. exp_avg) by iterating through
+                # the optimizer state shards saved using the previous topology
                 for state_key in OPTIMIZER_STATE_NAMES:
                     # TODO(xrsrke): free the memory of the shards that don't
                     # correspond to the current rank
                     buffer = torch.zeros_like(param, device="cuda")
                     unsharded_buffer = torch.empty(new_unshared_shape, device="cuda")
 
                     for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
-                        ckp_shard_metadata = get_checkpoint_state_metadata(param_name, pp_rank, tp_rank)
-                        ckp_shard_data = ckp_optim_state["state"][optim_state_index][state_key]
+                        old_optim_state_index = find_optim_index_from_param_name(
+                            base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
+                        )
+                        if old_optim_state_index is None:
+                            continue  # NOTE: the param is not in this pp shard
+                        ckp_shard_data = ckp_optim_state["state"][old_optim_state_index][state_key]
+                        # NOTE: the metadata for the main parameter of a tied parameter might be in a
+                        # different pipeline parallel shard.
+                        if param.is_tied:
+                            metadata_pp_rank = next(
+                                iter(param_shard_metadata[param_name.replace("module.", "")].keys())
+                            )[0]
+                        else:
+                            metadata_pp_rank = pp_rank
+                        ckp_shard_metadata = get_checkpoint_state_metadata(param_name, metadata_pp_rank, tp_rank)
 
                         # NOTE: if the checkpoint is from a Zero-1 optimizer,
                         # it's flattened, so we need to reshape it
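The inner loop above rebuilds each optimizer state tensor from the checkpoint's (pp_rank, tp_rank) shards: every shard that actually contains the parameter is located via `find_optim_index_from_param_name`, its data is merged into the full-size `unsharded_buffer`, and the current rank's slice is then cut back out (the role `merge_and_shard_tp_tensors` plays here). A conceptual sketch of that merge-then-reshard step, with an invented helper and invented slice metadata (nanotron's actual signatures differ):

```python
# Hedged sketch of the merge-then-reshard idea; `ckp_shards` pairs each old TP shard
# with the slice it occupies in the full (unsharded) tensor, which in nanotron comes
# from the checkpoint's shard metadata.
import torch

def merge_then_reshard(unsharded_buffer, ckp_shards, my_new_slice):
    for shard, old_slice in ckp_shards:
        unsharded_buffer[old_slice] = shard           # 1) reassemble the full tensor
    return unsharded_buffer[my_new_slice].clone()     # 2) keep only this rank's new shard

# Example: a state tensor of full shape (8,) saved with TP=2, reloaded with TP=4;
# this rank owns elements 2:4 in the new layout.
full = torch.empty(8)
old_shards = [
    (torch.arange(0, 4, dtype=torch.float), slice(0, 4)),
    (torch.arange(4, 8, dtype=torch.float), slice(4, 8)),
]
print(merge_then_reshard(full, old_shards, slice(2, 4)))  # tensor([2., 3.])
```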
@@ -232,7 +258,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
                             orig_shape = [int(dim) for dim in orig_shape]
                             ckp_shard_data = ckp_shard_data.view(orig_shape)
 
-                        new_optim_state_dict["state"][optim_state_index][state_key] = merge_and_shard_tp_tensors(
+                        new_optim_state_dict["state"][param_index][state_key] = merge_and_shard_tp_tensors(
                             buffer,
                             unsharded_buffer,
                             [
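The context lines above handle checkpoints written by `ZeroDistributedOptimizer`: its state tensors are stored flat (1-D), so each shard is viewed back to its original shape, taken from the shard metadata, before the TP merge. The `int(dim)` cast suggests the stored dimensions may not be plain ints. A tiny illustration of the reshape itself, with made-up shapes:

```python
# Illustration only: a ZeRO-1 style checkpoint stores each state as a flat tensor,
# so it is viewed back to its original (unflattened) shape before merging TP shards.
import torch

ckp_shard_data = torch.arange(12, dtype=torch.float)   # flat, as stored in the checkpoint
orig_shape = [int(dim) for dim in ("3", "4")]           # e.g. dims round-tripped as strings
ckp_shard_data = ckp_shard_data.view(orig_shape)
print(ckp_shard_data.shape)                             # torch.Size([3, 4])
```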
@@ -243,17 +269,16 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
 
                         if ckp_optim_type == ZeroDistributedOptimizer.__name__:
                             # NOTE: flatten the optimizer states
-                            new_optim_state_dict["state"][optim_state_index][state_key] = new_optim_state_dict[
-                                "state"
-                            ][optim_state_index][state_key].flatten()
-
-                        new_optim_state_dict["state"][optim_state_index]["step"] = ckp_optim_state["state"][optim_state_index][
-                            "step"
-                        ]
-
-        # NOTE: since all shards have the same optim state names
-        # so we take the first shard
-        new_optim_state_dict["names"] = ckp_sharded_optim_states[(0, 0)]["names"]
+                            new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][
+                                param_index
+                            ][state_key].flatten()
+                        # NOTE: a bit awkward, but while we're already reading this (pp, tp) shard for whatever state_key,
+                        # try to get the step value as well.
+                        step = ckp_optim_state["state"][old_optim_state_index].get("step")
+                        if step is not None:
+                            new_optim_state_dict["state"][param_index]["step"] = step
+
+        new_optim_state_dict["names"] = new_optim_state_param_names
         state_dict = new_optim_state_dict
     else:
         # TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be loading something else completely
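Two behavior changes land at the tail of the loop: the `"step"` scalar is now copied from whichever checkpoint shard actually held the parameter (instead of assuming every shard has an entry at the same index), and `"names"` is rebuilt from the locally constructed `new_optim_state_param_names` rather than copied from shard (0, 0). A toy sketch of that tail on plain dictionaries (the layout mimics a torch optimizer state dict; the indices and values are made up):

```python
# Toy sketch: flatten the merged state back to 1-D for a ZeRO-1 style optimizer and
# carry over the per-parameter "step" scalar from the checkpoint shard that held it.
import torch

ckp_shard_state = {3: {"exp_avg": torch.zeros(12), "exp_avg_sq": torch.zeros(12), "step": torch.tensor(500.0)}}
new_state = {0: {"exp_avg": torch.zeros(3, 4), "exp_avg_sq": torch.zeros(3, 4)}}
param_index, old_optim_state_index = 0, 3

for state_key in ("exp_avg", "exp_avg_sq"):
    new_state[param_index][state_key] = new_state[param_index][state_key].flatten()
step = ckp_shard_state[old_optim_state_index].get("step")
if step is not None:
    new_state[param_index]["step"] = step
print(new_state[0]["exp_avg"].shape, new_state[0]["step"])  # torch.Size([12]) tensor(500.)
```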