Commit f6dd495
Implement pipeline parallel-agnostic optimizer state loading
1 parent 31aa4c4 commit f6dd495

File tree

2 files changed: +55 −29 lines

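Before the diffs, the core idea: checkpoint optimizer shards are keyed by (pp_rank, tp_rank), and each shard carries its own index-to-parameter-name mapping. Once the pipeline-parallel (pp) size at resume time differs from the one at save time, a parameter's optimizer state can no longer be assumed to sit at a fixed index in shard (0, 0); it has to be located across pp shards. A minimal sketch of that lookup with toy data structures (the helper `find_shard_and_index` is hypothetical, not nanotron's API):

```python
# Toy data: two pipeline stages, one tp rank each; every shard maps its own
# local optimizer-state indices to parameter names.
ckp_sharded_optim_states = {
    (0, 0): {"names": {0: "embed.weight"}, "state": {0: {"exp_avg": "..."}}},
    (1, 0): {"names": {0: "lm_head.weight"}, "state": {0: {"exp_avg": "..."}}},
}

def find_shard_and_index(param_name):
    """Return ((pp, tp), index) of the checkpoint shard that owns param_name."""
    for (pp, tp), shard in ckp_sharded_optim_states.items():
        for index, name in shard["names"].items():
            if name == param_name:
                return (pp, tp), index
    return None  # parameter found in no shard

print(find_shard_and_index("lm_head.weight"))  # ((1, 0), 0)
```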

src/nanotron/optim/zero.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -348,16 +348,17 @@ def find_optim_index_from_param_name(
     # NOTE: (pp_rank, dp_rank, tp_rank) or (pp_rank, tp_rank)
     ckp_sharded_optim_states: Union[Tuple[Tuple[int, int, int], torch.Tensor], Tuple[Tuple[int, int], torch.Tensor]],
     is_zero1: bool,
+    pp_rank=0,
 ) -> int:
     param_name = param_name.replace("module.", "")
     # NOTE: since all shards have the same optim state names
-    # so we take the first shard
+    # so we take the first shard (except optionally the pp dimension)
     if is_zero1 is True:
         # NOTE: (pp_rank, dp_rank, tp_rank)
-        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(0, 0, 0)]["names"]
+        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0, 0)]["names"]
     else:
         # NOTE: (pp_rank, tp_rank)
-        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(0, 0)]["names"]
+        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0)]["names"]

     return next((k for k, v in OPTIM_STATE_INDEX_TO_PARAM_NAME.items() if v == param_name), None)
```
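To see what the new `pp_rank` argument buys the caller, here is a self-contained toy version of the function above, with fabricated shard dicts; a `None` return now signals that the parameter's optimizer state lives in a different pipeline stage's shard:

```python
# Standalone mirror of the changed lookup: the caller probes one pp shard at
# a time instead of always reading shard 0.
def find_optim_index_from_param_name(param_name, ckp_sharded_optim_states, is_zero1, pp_rank=0):
    param_name = param_name.replace("module.", "")
    key = (pp_rank, 0, 0) if is_zero1 else (pp_rank, 0)
    names = ckp_sharded_optim_states[key]["names"]
    return next((k for k, v in names.items() if v == param_name), None)

shards = {
    (0, 0): {"names": {0: "embed.weight"}},
    (1, 0): {"names": {0: "lm_head.weight"}},
}
# lm_head lives in pp stage 1, so stage 0 reports None rather than a bogus index.
assert find_optim_index_from_param_name("lm_head.weight", shards, is_zero1=False, pp_rank=0) is None
assert find_optim_index_from_param_name("lm_head.weight", shards, is_zero1=False, pp_rank=1) == 0
```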

src/nanotron/serialize/optimizer.py

Lines changed: 51 additions & 26 deletions
```diff
@@ -141,7 +141,9 @@ def load_optimizer(
     ckp_tp_size = ckp_optimizer_config["parallelism"]["tp_size"]
     ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]

-    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()):
+    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
+        parallel_context.pp_pg.size()
+    ):
         assert (
             param_shard_metadata is not None
         ), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
```
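The guard now also fires when the checkpoint's pp size differs from the current one (`ckp_pp_size` is presumably read next to `ckp_tp_size` just above the hunk; the assert message still mentions only tp, which the commit leaves as is). A hedged sketch of the condition with hypothetical config values:

```python
# Hypothetical values; the real ones come from the checkpoint's parallelism config.
ckp_parallelism = {"tp_size": 2, "pp_size": 4}  # topology at save time
current = {"tp_size": 2, "pp_size": 2}          # topology at resume time

needs_remapping = (
    ckp_parallelism["tp_size"] != current["tp_size"]
    or ckp_parallelism["pp_size"] != current["pp_size"]  # new in this commit
)
print(needs_remapping)  # True: pp size changed, so optimizer states are remapped
```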
```diff
@@ -179,14 +181,17 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -

         model_state_dict = model.state_dict()
         new_optim_state_dict = optimizer.state_dict()
+        # TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
         OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
         # NOTE: because we can only resume training with the same optimizer type
         # (0, 0) = (pp_rank, tp_rank)
         # NOTE: also we don't merge "step" because it's just a scalar
-
-        param_names = sorted(model_state_dict.items(), key=lambda x: x[0])
-        for param_name, _ in tqdm(
-            param_names,
+        param_names = list(model_state_dict.keys())
+        new_optim_state_param_names = {}
+        # NOTE: iterates through all model parameters in the local pipeline parallel rank (hence, might not be the full model).
+        # Since model parameters and optimizer states are aligned, loads only the optimizer states for these parameters from the checkpoint shards.
+        for param_index, param_name in tqdm(
+            enumerate(param_names),
             disable=dist.get_rank(parallel_context.world_pg) != 0,
             desc="Topology-agnostic optimizer loading",
         ):
```
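The loop now enumerates the local rank's parameters and indexes the rebuilt optimizer state by that local position (`param_index`), instead of sorting names and reusing the checkpoint's indices. This relies on Python dicts preserving insertion order, so that `param_index` lines up with the optimizer's own state slots. A toy illustration:

```python
# Optimizer state slots follow the *local* parameter order of the resuming
# rank, and "names" is rebuilt to match, instead of being copied verbatim
# from checkpoint shard (0, 0).
model_state_dict = {"layers.2.weight": None, "layers.3.weight": None}  # this pp rank only

param_names = list(model_state_dict.keys())  # insertion order, not sorted
new_optim_state_param_names = {}
for param_index, param_name in enumerate(param_names):
    new_optim_state_param_names[param_index] = param_name

print(new_optim_state_param_names)  # {0: 'layers.2.weight', 1: 'layers.3.weight'}
```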
```diff
@@ -198,28 +203,49 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
             if not isinstance(param, NanotronParameter):
                 raise NotImplementedError("Parameters are required to be NanotronParameter")

+            # NOTE: for tied parameters, the metadata is stored using the parameter name,
+            # while the data is stored using the name of the main tied parameter,
+            # which may be different (e.g. `model.token_position_embeddings.pp_block.token_embedding.weight`
+            # for `model.lm_head.pp_block.weight`).
+            base_name = param.get_tied_info().name if param.is_tied else param_name
+            if param_name != base_name:
+                # NOTE: skip tied parameter if main tied parameter has already been loaded
+                # (not always the case if pipeline parallel)
+                if base_name in new_optim_state_param_names.values():
+                    continue
+            new_optim_state_param_names[param_index] = base_name
+
             if param.is_sharded:
                 # NOTE: an optimizer state's shape is equal to the parameter's shape
                 # NOTE: sometimes an unsharded parameter's shape differs
                 # from an unsharded optimizer state's shape
                 new_shard_metadata = param.get_sharded_info()
                 new_unshared_shape = new_shard_metadata.unsharded_shape
-
-                # NOTE: merging optimizer states
-                optim_state_index = find_optim_index_from_param_name(
-                    param_name, ckp_sharded_optim_states, is_zero1=False
-                )
-
-                new_optim_state_dict["state"][optim_state_index] = {}
+                new_optim_state_dict["state"][param_index] = {}
+                # NOTE: restore each state tensor (e.g. exp_avg) by iterating through
+                # the optimizer state shards saved using the previous topology
                 for state_key in OPTIMIZER_STATE_NAMES:
                     # TODO(xrsrke): free the memory of the shards that don't
                     # correspond to the current rank
                     buffer = torch.zeros_like(param, device="cuda")
                     unsharded_buffer = torch.empty(new_unshared_shape, device="cuda")

                     for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
-                        ckp_shard_metadata = get_checkpoint_state_metadata(param_name, pp_rank, tp_rank)
-                        ckp_shard_data = ckp_optim_state["state"][optim_state_index][state_key]
+                        old_optim_state_index = find_optim_index_from_param_name(
+                            base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
+                        )
+                        if old_optim_state_index is None:
+                            continue  # NOTE: param is not in this pp shard
+                        ckp_shard_data = ckp_optim_state["state"][old_optim_state_index][state_key]
+                        # NOTE: the metadata for the main parameter of a tied parameter might be in a
+                        # different pipeline parallel shard.
+                        if param.is_tied:
+                            metadata_pp_rank = next(
+                                iter(param_shard_metadata[param_name.replace("module.", "")].keys())
+                            )[0]
+                        else:
+                            metadata_pp_rank = pp_rank
+                        ckp_shard_metadata = get_checkpoint_state_metadata(param_name, metadata_pp_rank, tp_rank)

                         # NOTE: if the checkpoint is from a Zero-1 optimizer,
                         # it's flattened, so we need to reshape it
```
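The tied-parameter handling above is the subtle part: optimizer state is stored once, under the main tied name, and with pipeline parallelism the main parameter may have already been loaded by the time its tied copy is visited. A toy walk-through using the names from the diff's own comment (the `tied` mapping stands in for `get_tied_info()` and is fabricated):

```python
# lm_head's weight is tied to the token embedding, so its optimizer state is
# stored once under the embedding's ("main") name and must be loaded only once.
tied = {
    "model.lm_head.pp_block.weight": "model.token_position_embeddings.pp_block.token_embedding.weight"
}

new_optim_state_param_names = {}
for param_index, param_name in enumerate([
    "model.token_position_embeddings.pp_block.token_embedding.weight",
    "model.lm_head.pp_block.weight",
]):
    base_name = tied.get(param_name, param_name)  # stand-in for param.get_tied_info().name
    if param_name != base_name and base_name in new_optim_state_param_names.values():
        continue  # main tied parameter already loaded on this rank
    new_optim_state_param_names[param_index] = base_name

print(new_optim_state_param_names)
# {0: 'model.token_position_embeddings.pp_block.token_embedding.weight'}
```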
```diff
@@ -229,7 +255,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
                             orig_shape = [int(dim) for dim in orig_shape]
                             ckp_shard_data = ckp_shard_data.view(orig_shape)

-                        new_optim_state_dict["state"][optim_state_index][state_key] = merge_and_shard_tp_tensors(
+                        new_optim_state_dict["state"][param_index][state_key] = merge_and_shard_tp_tensors(
                             buffer,
                             unsharded_buffer,
                             [
```
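`merge_and_shard_tp_tensors` itself is unchanged here; only its output slot moves from `optim_state_index` to `param_index`. Its full signature isn't visible in this diff, so the following is only a conceptual stand-in (hypothetical `merge_and_shard_dim0`) for the merge-then-reshard pattern on a dim-0 tensor-parallel shard:

```python
import torch

def merge_and_shard_dim0(buffer, unsharded_buffer, shards_with_offsets, my_slice):
    """Reassemble an unsharded tensor from checkpoint shards, then cut out
    the slice the current rank owns."""
    for start, shard in shards_with_offsets:      # place each checkpoint shard
        unsharded_buffer[start : start + shard.shape[0]] = shard
    buffer.copy_(unsharded_buffer[my_slice])      # re-shard for this rank
    return buffer

unsharded = torch.empty(4, 2)
mine = torch.zeros(2, 2)
shards = [(0, torch.ones(2, 2)), (2, torch.full((2, 2), 2.0))]
print(merge_and_shard_dim0(mine, unsharded, shards, slice(2, 4)))  # all 2.0s
```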
```diff
@@ -240,17 +266,16 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -

                         if ckp_optim_type == ZeroDistributedOptimizer.__name__:
                             # NOTE: flatten the optimizer states
-                            new_optim_state_dict["state"][optim_state_index][state_key] = new_optim_state_dict[
-                                "state"
-                            ][optim_state_index][state_key].flatten()
-
-                new_optim_state_dict["state"][optim_state_index]["step"] = ckp_optim_state["state"][optim_state_index][
-                    "step"
-                ]
-
-        # NOTE: since all shards have the same optim state names
-        # so we take the first shard
-        new_optim_state_dict["names"] = ckp_sharded_optim_states[(0, 0)]["names"]
+                            new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][
+                                param_index
+                            ][state_key].flatten()
+                        # NOTE: a bit awkward, but while we're already reading this (pp, tp) shard for whatever state_key,
+                        # try to get the step value as well.
+                        step = ckp_optim_state["state"][old_optim_state_index].get("step")
+                        if step is not None:
+                            new_optim_state_dict["state"][param_index]["step"] = step
+
+        new_optim_state_dict["names"] = new_optim_state_param_names

         state_dict = new_optim_state_dict
     else:
         # TODO @thomasw21: Load optimizer type and check that it's compatible, otherwise we might be loading something else completely
```
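Finally, the `step` scalar: the old code indexed `["step"]` unconditionally, which would raise a `KeyError` once a parameter is absent from a given pp shard. The new code reads it opportunistically with `.get` while a shard is already being read. A small sketch of that pattern with toy shards:

```python
# With pipeline parallelism, the parameter (and hence its "step") exists in
# only one pp shard, so the value is copied only when present.
shards = {
    (0, 0): {"state": {0: {"exp_avg": "...", "step": 1000}}},
    (1, 0): {"state": {}},  # this pp stage never held the parameter
}

new_slot = {}
for (pp, tp), ckp_optim_state in shards.items():
    state = ckp_optim_state["state"].get(0)
    if state is None:
        continue  # parameter not in this pp shard
    step = state.get("step")
    if step is not None:
        new_slot["step"] = step
print(new_slot)  # {'step': 1000}
```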
