
Commit 4cdc2cd

bigning authored: fix 2.4.1ckpt (#3629)
1 parent d2e1d5e commit 4cdc2cd
2 files changed: 31 additions, 44 deletions

composer/trainer/_patch_pytorch.py

Lines changed: 1 addition & 5 deletions
@@ -945,8 +945,7 @@ def unshard_with_sync(self):
 
 if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
     torch.__version__,
-) < version.parse('2.4.1'):
-    # 2.4.0 only patch
+) < version.parse('2.4.2'):
     # PyTorch issue: https://github.com/pytorch/pytorch/issues/133923
     from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
     from typing import Mapping, Collection
@@ -1003,9 +1002,6 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
         for key, value in state_dict.items():
             _traverse_obj((str(key),), value)
 
-if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
-    torch.__version__,
-) < version.parse('2.4.2'):
     # Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
     from torch.distributed.fsdp._flat_param import FlatParamHandle
     original_unshard = FlatParamHandle.unshard
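The net effect of this file's changes is to widen the version gate from [2.4.0, 2.4.1) to [2.4.0, 2.4.2), so the patch that previously applied only to torch 2.4.0 now also covers 2.4.1, and the separately 2.4.2-bounded block in the second hunk is folded into that same branch. A minimal sketch of the widened gate, assuming only the packaging library; the helper name needs_traverse_patch is hypothetical and not part of Composer:

from packaging import version

def needs_traverse_patch(torch_version: str) -> bool:
    """True for torch releases covered by the merged [2.4.0, 2.4.2) gate."""
    parsed = version.parse(torch_version)
    return version.parse('2.4.0') <= parsed < version.parse('2.4.2')

# Before this commit the upper bound was 2.4.1, so torch 2.4.1 was left unpatched.
assert needs_traverse_patch('2.4.0')
assert needs_traverse_patch('2.4.1')
assert not needs_traverse_patch('2.4.2')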

composer/utils/checkpoint.py

Lines changed: 30 additions & 39 deletions
@@ -623,50 +623,41 @@ def dist_cp_load(
     load_planner: Optional[LoadPlanner] = None,
 ):
     if version.parse(torch.__version__) >= version.parse('2.4.0'):
-        if version.parse(torch.__version__) < version.parse('2.4.1'):
-            # PyTorch 2.4.0
-            from torch.distributed.checkpoint.utils import CheckpointException
-            try:
-                dist_cp.load(
-                    state_dict=state_dict,
-                    storage_reader=storage_reader,
-                    planner=load_planner,
-                )
-            except CheckpointException as e:
-                checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
-                if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
-                    # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility.
-                    # Torch issue: https://github.com/pytorch/pytorch/issues/133923.
-                    # We override the traverse_state_dict so that the load planner could
-                    # use the old way of flattening the state dict
-                    log.debug('Trying to load checkpointing saved before torch 2.4')
-
-                    import torch.distributed.checkpoint._nested_dict as nested_dict
-                    import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
-                    from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0
-
-                    from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse
-
-                    nested_dict.traverse_state_dict = backward_compatible_traverse
-                    sharded_tensor_util.traverse_state_dict = backward_compatible_traverse
-
-                    dist_cp.load(
-                        state_dict=state_dict,
-                        storage_reader=storage_reader,
-                        planner=load_planner,
-                    )
-                    # Revert the override
-                    nested_dict.traverse_state_dict = traverse_2_4_0
-                    sharded_tensor_util.traverse_state_dict = traverse_2_4_0
-                else:
-                    raise e
-        else:
-            # PyTorch 2.4.1
+        from torch.distributed.checkpoint.utils import CheckpointException
+        try:
             dist_cp.load(
                 state_dict=state_dict,
                 storage_reader=storage_reader,
                 planner=load_planner,
             )
+        except CheckpointException as e:
+            checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
+            if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
+                # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility.
+                # Torch issue: https://github.com/pytorch/pytorch/issues/133923.
+                # We override the traverse_state_dict so that the load planner could
+                # use the old way of flattening the state dict
+                log.debug('Trying to load checkpointing saved before torch 2.4')
+
+                import torch.distributed.checkpoint._nested_dict as nested_dict
+                import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
+                from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0
+
+                from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse
+
+                nested_dict.traverse_state_dict = backward_compatible_traverse
+                sharded_tensor_util.traverse_state_dict = backward_compatible_traverse
+
+                dist_cp.load(
+                    state_dict=state_dict,
+                    storage_reader=storage_reader,
+                    planner=load_planner,
+                )
+                # Revert the override
+                nested_dict.traverse_state_dict = traverse_2_4_0
+                sharded_tensor_util.traverse_state_dict = traverse_2_4_0
+            else:
+                raise e
     else:
         dist_cp.load_state_dict(
             state_dict=state_dict,
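With the PyTorch-2.4.0-only branch removed, the same try/except path now runs for every torch >= 2.4.0: attempt dist_cp.load, and if it raises CheckpointException for a checkpoint whose metadata indicates it was saved before torch 2.4 (and before Composer recorded its version), temporarily override traverse_state_dict with the backward-compatible implementation, retry the load, then revert the override. A minimal generic sketch of that patch-retry-revert pattern follows; the names temporarily_patched, load_with_fallback, module, attr, and load_fn are hypothetical placeholders rather than Composer APIs, and the checkpoint-metadata check performed by the real code is omitted:

from contextlib import contextmanager

@contextmanager
def temporarily_patched(module, attr, replacement):
    """Swap module.attr for `replacement` inside the with-block, then restore it."""
    original = getattr(module, attr)
    setattr(module, attr, replacement)
    try:
        yield
    finally:
        setattr(module, attr, original)

def load_with_fallback(load_fn, module, attr, replacement, expected_exc):
    """Try load_fn once; on the expected exception, retry under the temporary patch."""
    try:
        return load_fn()
    except expected_exc:
        # Retry with the backward-compatible traversal, mirroring the second
        # dist_cp.load call in the diff, then restore the original function.
        with temporarily_patched(module, attr, replacement):
            return load_fn()

Using a context manager means the revert happens even if the retry also fails, which slightly strengthens the inline revert shown in the diff.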
