Commit 372fdc1

Merge pull request #71 from nopperl/topology-agnostic-loading
Implement pipeline parallel size-agnostic optimizer state loading
2 parents 1676cec + f6dd495 commit 372fdc1

2 files changed: +55 additions, -29 deletions
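
Both diffs operate on the same checkpoint structure: a dict of optimizer state shards keyed by (pp_rank, tp_rank), each holding an index-to-parameter-name map ("names") and per-index state tensors plus a "step" scalar. A toy sketch of that layout, as far as it can be read off the diffs below; the tensor shapes and Adam-style state names are illustrative, not taken from a real checkpoint:

# Illustrative sketch only (not nanotron code): assumed layout of the checkpoint's
# sharded optimizer states, as suggested by the diffs below.
import torch

ckp_sharded_optim_states = {
    (0, 0): {
        "names": {0: "model.decoder.0.weight"},
        "state": {0: {"exp_avg": torch.zeros(4, 4), "exp_avg_sq": torch.zeros(4, 4), "step": 10}},
    },
    (1, 0): {
        "names": {0: "model.decoder.1.weight"},
        "state": {0: {"exp_avg": torch.zeros(4, 4), "exp_avg_sq": torch.zeros(4, 4), "step": 10}},
    },
}

# With pipeline-parallel-size-agnostic loading, a parameter's state may live in any pp shard,
# so every (pp_rank, tp_rank) key has to be consulted rather than only pp_rank == 0.
print(sorted(ckp_sharded_optim_states.keys()))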

src/nanotron/optim/zero.py

Lines changed: 4 additions & 3 deletions
@@ -348,16 +348,17 @@ def find_optim_index_from_param_name(
     # NOTE: (pp_rank, dp_rank, tp_rank) or (pp_rank, tp_rank)
     ckp_sharded_optim_states: Union[Tuple[Tuple[int, int, int], torch.Tensor], Tuple[Tuple[int, int], torch.Tensor]],
     is_zero1: bool,
+    pp_rank=0,
 ) -> int:
     param_name = param_name.replace("module.", "")
     # NOTE: since all shards have the same optim state names
-    # so we take the first shard
+    # so we take the first shard (except optionally the pp dimension)
     if is_zero1 is True:
         # NOTE: (pp_rank, dp_rank, tp_rank)
-        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(0, 0, 0)]["names"]
+        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0, 0)]["names"]
     else:
         # NOTE: (pp_rank, tp_rank)
-        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(0, 0)]["names"]
+        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0)]["names"]
 
     return next((k for k, v in OPTIM_STATE_INDEX_TO_PARAM_NAME.items() if v == param_name), None)
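
The effect of the new pp_rank argument: the name-to-index lookup is done against a chosen pipeline parallel shard instead of always shard (0, ...), and it returns None when the parameter does not live in that shard. A standalone re-implementation of that lookup over toy data (a sketch, not an import of the nanotron function; only the non-Zero-1 branch is mirrored):

# Sketch of the lookup behaviour added above, on toy data.
def find_index(param_name, sharded_states, pp_rank=0):
    names = sharded_states[(pp_rank, 0)]["names"]  # (pp_rank, tp_rank); tp_rank 0 suffices for names
    return next((k for k, v in names.items() if v == param_name), None)

shards = {
    (0, 0): {"names": {0: "model.decoder.0.weight"}},
    (1, 0): {"names": {0: "model.lm_head.pp_block.weight"}},
}
assert find_index("model.lm_head.pp_block.weight", shards, pp_rank=0) is None  # not in pp rank 0
assert find_index("model.lm_head.pp_block.weight", shards, pp_rank=1) == 0     # found in pp rank 1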

src/nanotron/serialize/optimizer.py

Lines changed: 51 additions & 26 deletions
@@ -144,7 +144,9 @@ def load_optimizer(
     ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]
     ckpt_expert_parallel_size = ckp_optimizer_config["parallelism"]["expert_parallel_size"]
 
-    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()):
+    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
+        parallel_context.pp_pg.size()
+    ):
         assert (
             param_shard_metadata is not None
         ), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
@@ -182,14 +184,17 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
 
         model_state_dict = model.state_dict()
         new_optim_state_dict = optimizer.state_dict()
+        # TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
         OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
         # NOTE: because we can only resume training with the same optimizer type
         # (0, 0) = (pp_rank, tp_rank)
         # NOTE: also we don't merge "step" because it's just a scalar
-
-        param_names = sorted(model_state_dict.items(), key=lambda x: x[0])
-        for param_name, _ in tqdm(
-            param_names,
+        param_names = list(model_state_dict.keys())
+        new_optim_state_param_names = {}
+        # NOTE: iterates through all model parameters in the local pipeline parallel rank (hence, might not be the full model).
+        # Since model parameters and optimizer states are aligned, loads only the optimizer states for these parameters from the checkpoint shards.
+        for param_index, param_name in tqdm(
+            enumerate(param_names),
             disable=dist.get_rank(parallel_context.world_pg) != 0,
             desc="Topology-agnostic optimizer loading",
         ):
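
Because the loop now enumerates the local rank's parameters, optimizer state entries are re-indexed by position in the local model rather than by the index stored in the checkpoint, and new_optim_state_param_names records which parameter each new index refers to. A toy sketch of that re-indexing (the parameter names are hypothetical):

# Sketch: build the index -> name map for the parameters owned by this pp rank.
local_param_names = ["model.decoder.2.weight", "model.decoder.3.weight"]  # hypothetical local shard

new_optim_state_param_names = {}
for param_index, param_name in enumerate(local_param_names):
    new_optim_state_param_names[param_index] = param_name

# These indices key the rebuilt optimizer "state" dict and become its "names" entry at the end.
print(new_optim_state_param_names)  # {0: 'model.decoder.2.weight', 1: 'model.decoder.3.weight'}
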
@@ -201,28 +206,49 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
             if not isinstance(param, NanotronParameter):
                 raise NotImplementedError("Parameters are required to be NanotronParameter")
 
+            # NOTE: for tied parameters, the metadata is stored using the parameter name,
+            # while the data is stored using the name of the main tied parameter,
+            # which may be different (e.g. `model.token_position_embeddings.pp_block.token_embedding.weight`
+            # for `model.lm_head.pp_block.weight`).
+            base_name = param.get_tied_info().name if param.is_tied else param_name
+            if param_name != base_name:
+                # NOTE: skip tied parameter if main tied parameter has already been loaded
+                # (not always the case if pipeline parallel)
+                if base_name in new_optim_state_param_names.values():
+                    continue
+            new_optim_state_param_names[param_index] = base_name
+
             if param.is_sharded:
                 # NOTE: optimizer states's shape is equal to the parameter's shape
                 # NOTE: sometines an unsharded parameter's shape differ
                 # from an unsharded optimizer state's shape
                 new_shard_metadata = param.get_sharded_info()
                 new_unshared_shape = new_shard_metadata.unsharded_shape
-
-                # NOTE: merging optimizer states
-                optim_state_index = find_optim_index_from_param_name(
-                    param_name, ckp_sharded_optim_states, is_zero1=False
-                )
-
-                new_optim_state_dict["state"][optim_state_index] = {}
+                new_optim_state_dict["state"][param_index] = {}
+                # NOTE: restore each state tensor (e.g. exg_avg) by iterating through
+                # the optimizer state shards saved using the previous topology
                 for state_key in OPTIMIZER_STATE_NAMES:
                     # TODO(xrsrke): free the memory of the shards that isn't
                     # corresponding to the current rank
                     buffer = torch.zeros_like(param, device="cuda")
                     unsharded_buffer = torch.empty(new_unshared_shape, device="cuda")
 
                     for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
-                        ckp_shard_metadata = get_checkpoint_state_metadata(param_name, pp_rank, tp_rank)
-                        ckp_shard_data = ckp_optim_state["state"][optim_state_index][state_key]
+                        old_optim_state_index = find_optim_index_from_param_name(
+                            base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
+                        )
+                        if old_optim_state_index is None:
+                            continue  # NOTE: param is not in this pp shard
+                        ckp_shard_data = ckp_optim_state["state"][old_optim_state_index][state_key]
+                        # NOTE: the metadata for the main parameter of a tied parameter might be in a
+                        # different pipeline parallel shard.
+                        if param.is_tied:
+                            metadata_pp_rank = next(
+                                iter(param_shard_metadata[param_name.replace("module.", "")].keys())
+                            )[0]
+                        else:
+                            metadata_pp_rank = pp_rank
+                        ckp_shard_metadata = get_checkpoint_state_metadata(param_name, metadata_pp_rank, tp_rank)
 
                         # NOTE: if the checkpoint is from a Zero-1 optimizer,
                         # so it's flattened, so we need to reshape it
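
Two things happen per parameter in this hunk: tied parameters are resolved to the name of their main tied parameter before any lookup, and checkpoint pp shards that do not contain the parameter are skipped. A condensed sketch of that control flow on toy data; the parameter names follow the example in the diff's comment, and the shard contents are invented:

# Sketch: per-shard lookup with tied-name resolution and skipping of non-owning pp shards.
tied_name_map = {  # param name -> main tied parameter name (identity for untied params)
    "model.lm_head.pp_block.weight": "model.token_position_embeddings.pp_block.token_embedding.weight",
}
shards = {
    (0, 0): {"names": {0: "model.token_position_embeddings.pp_block.token_embedding.weight"}},
    (1, 0): {"names": {0: "model.decoder.7.weight"}},
}

param_name = "model.lm_head.pp_block.weight"
base_name = tied_name_map.get(param_name, param_name)

for (pp_rank, tp_rank) in shards:
    names = shards[(pp_rank, tp_rank)]["names"]
    index = next((k for k, v in names.items() if v == base_name), None)
    if index is None:
        continue  # this pp shard does not hold the parameter's optimizer state
    print(f"found {base_name!r} at index {index} in pp shard {pp_rank}")
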
@@ -232,7 +258,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
                             orig_shape = [int(dim) for dim in orig_shape]
                             ckp_shard_data = ckp_shard_data.view(orig_shape)
 
-                        new_optim_state_dict["state"][optim_state_index][state_key] = merge_and_shard_tp_tensors(
+                        new_optim_state_dict["state"][param_index][state_key] = merge_and_shard_tp_tensors(
                             buffer,
                             unsharded_buffer,
                             [
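
For Zero-1 checkpoints the per-parameter state tensors were saved flattened, so each shard is viewed back to its original shape before the tensor parallel shards are merged, and, as the following hunk shows, flattened again when the optimizer type is ZeroDistributedOptimizer. A minimal torch-only sketch of that view/flatten round trip; the shapes and the merge step are illustrative stand-ins:

# Sketch: un-flatten a Zero-1 state shard for merging, then flatten the merged result again.
import torch

orig_shape = (4, 8)                                 # hypothetical unflattened shape from the shard metadata
ckp_shard_data = torch.arange(32.0)                 # stand-in for a flattened exp_avg shard
ckp_shard_data = ckp_shard_data.view(orig_shape)    # reshape so tp merging works on real dimensions

merged = ckp_shard_data                             # stand-in for the merge_and_shard_tp_tensors result
flattened_again = merged.flatten()                  # stored flattened again for a ZeroDistributedOptimizer
assert flattened_again.shape == (32,)
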
@@ -243,17 +269,16 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
 
                         if ckp_optim_type == ZeroDistributedOptimizer.__name__:
                             # NOTE: flatten the optimizer states
-                            new_optim_state_dict["state"][optim_state_index][state_key] = new_optim_state_dict[
-                                "state"
-                            ][optim_state_index][state_key].flatten()
-
-                new_optim_state_dict["state"][optim_state_index]["step"] = ckp_optim_state["state"][optim_state_index][
-                    "step"
-                ]
-
-        # NOTE: since all shards have the same optim state names
-        # so we take the first shard
-        new_optim_state_dict["names"] = ckp_sharded_optim_states[(0, 0)]["names"]
+                            new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][
+                                param_index
+                            ][state_key].flatten()
+                        # NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
+                        # try to get the step value as well.
+                        step = ckp_optim_state["state"][old_optim_state_index].get("step")
+                        if step is not None:
+                            new_optim_state_dict["state"][param_index]["step"] = step
+
+        new_optim_state_dict["names"] = new_optim_state_param_names
         state_dict = new_optim_state_dict
     else:
         # TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely
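
After the loop, the rebuilt state dict keys its per-parameter states by the new local param_index, each entry carrying the merged state tensors plus the step scalar picked up from whichever checkpoint shard held the parameter, and "names" is the locally built index-to-name map rather than a copy from shard (0, 0). A sketch of the resulting shape of new_optim_state_dict; the tensors, names, and state keys are placeholders:

# Sketch: approximate shape of the rebuilt optimizer state dict, with placeholder values.
import torch

new_optim_state_dict = {
    "state": {
        0: {"exp_avg": torch.zeros(4, 4), "exp_avg_sq": torch.zeros(4, 4), "step": 10},
        1: {"exp_avg": torch.zeros(8, 4), "exp_avg_sq": torch.zeros(8, 4), "step": 10},
    },
    # keyed by the same local param_index values as "state", not by the checkpoint's indices
    "names": {0: "model.decoder.2.weight", 1: "model.decoder.3.weight"},
}
assert new_optim_state_dict["state"].keys() == new_optim_state_dict["names"].keys()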
