@@ -623,50 +623,41 @@ def dist_cp_load(
     load_planner: Optional[LoadPlanner] = None,
 ):
     if version.parse(torch.__version__) >= version.parse('2.4.0'):
-        if version.parse(torch.__version__) < version.parse('2.4.1'):
-            # PyTorch 2.4.0
-            from torch.distributed.checkpoint.utils import CheckpointException
-            try:
-                dist_cp.load(
-                    state_dict=state_dict,
-                    storage_reader=storage_reader,
-                    planner=load_planner,
-                )
-            except CheckpointException as e:
-                checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
-                if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
-                    # Torch 2.4 changed how the state dict is flattened, which broke backward compatibility.
-                    # Torch issue: https://github.com/pytorch/pytorch/issues/133923.
-                    # We override traverse_state_dict so that the load planner can
-                    # use the old way of flattening the state dict.
-                    log.debug('Trying to load checkpointing saved before torch 2.4')
-
-                    import torch.distributed.checkpoint._nested_dict as nested_dict
-                    import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
-                    from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0
-
-                    from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse
-
-                    nested_dict.traverse_state_dict = backward_compatible_traverse
-                    sharded_tensor_util.traverse_state_dict = backward_compatible_traverse
-
-                    dist_cp.load(
-                        state_dict=state_dict,
-                        storage_reader=storage_reader,
-                        planner=load_planner,
-                    )
-                    # Revert the override
-                    nested_dict.traverse_state_dict = traverse_2_4_0
-                    sharded_tensor_util.traverse_state_dict = traverse_2_4_0
-                else:
-                    raise e
-        else:
-            # PyTorch 2.4.1
+        from torch.distributed.checkpoint.utils import CheckpointException
+        try:
             dist_cp.load(
                 state_dict=state_dict,
                 storage_reader=storage_reader,
                 planner=load_planner,
             )
+        except CheckpointException as e:
+            checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
+            if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
+                # Torch 2.4 changed how the state dict is flattened, which broke backward compatibility.
+                # Torch issue: https://github.com/pytorch/pytorch/issues/133923.
+                # We override traverse_state_dict so that the load planner can
+                # use the old way of flattening the state dict.
+                log.debug('Trying to load checkpointing saved before torch 2.4')
+
+                import torch.distributed.checkpoint._nested_dict as nested_dict
+                import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
+                from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0
+
+                from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse
+
+                nested_dict.traverse_state_dict = backward_compatible_traverse
+                sharded_tensor_util.traverse_state_dict = backward_compatible_traverse
+
+                dist_cp.load(
+                    state_dict=state_dict,
+                    storage_reader=storage_reader,
+                    planner=load_planner,
+                )
+                # Revert the override
+                nested_dict.traverse_state_dict = traverse_2_4_0
+                sharded_tensor_util.traverse_state_dict = traverse_2_4_0
+            else:
+                raise e
     else:
         dist_cp.load_state_dict(
             state_dict=state_dict,
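The except-branch above temporarily overrides torch's internal traverse_state_dict references and then restores them by hand, so an exception raised during the fallback dist_cp.load call would leave the override in place. Below is a minimal sketch of the same patch-then-revert pattern written as a context manager so the restore runs even on failure. temporarily_patched is a hypothetical helper, not part of Composer or PyTorch, and the commented usage only mirrors the names appearing in the diff.

import contextlib

@contextlib.contextmanager
def temporarily_patched(module, attr_name, replacement):
    """Swap module.attr_name for replacement inside the with block, then restore it."""
    original = getattr(module, attr_name)
    setattr(module, attr_name, replacement)
    try:
        yield
    finally:
        # Restore the original attribute even if the body raises.
        setattr(module, attr_name, original)

# Hypothetical usage mirroring the except-branch above:
# with temporarily_patched(nested_dict, 'traverse_state_dict', backward_compatible_traverse), \
#      temporarily_patched(sharded_tensor_util, 'traverse_state_dict', backward_compatible_traverse):
#     dist_cp.load(state_dict=state_dict, storage_reader=storage_reader, planner=load_planner)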