diff --git a/composer/cli/launcher.py b/composer/cli/launcher.py index 91110c2add..2c22a8cf6c 100755 --- a/composer/cli/launcher.py +++ b/composer/cli/launcher.py @@ -19,7 +19,6 @@ import psutil import torch -from packaging import version import composer from composer.loggers.mosaicml_logger import ( @@ -321,10 +320,7 @@ def _launch_processes( log.info('Distributed KV store: tcp://%s:%s', master_addr, master_port) nccl_env_variable = { - ( - 'NCCL_ASYNC_ERROR_HANDLING' if version.parse(torch.__version__) < version.parse('2.2.0') else 'TORCH_NCCL_ASYNC_ERROR_HANDLING' - ): - '1', + 'TORCH_NCCL_ASYNC_ERROR_HANDLING': '1', } for local_rank in range(nproc): diff --git a/composer/core/state.py b/composer/core/state.py index db652c8055..2e7ac11303 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -11,13 +11,20 @@ from collections import OrderedDict from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Sequence, Union, cast -from unittest.mock import MagicMock import numpy as np import torch import torch.nn.modules.utils from packaging import version +from torch.amp.grad_scaler import GradScaler from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + set_model_state_dict, + set_optimizer_state_dict, +) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullOptimStateDictConfig, @@ -31,11 +38,6 @@ from torch.utils.data import DataLoader, Dataset from torchmetrics import Metric -if version.parse(torch.__version__) >= version.parse('2.3.0'): - from torch.amp.grad_scaler import GradScaler # type: ignore -else: - from torch.cuda.amp.grad_scaler import GradScaler # type: ignore - from composer.core.data_spec import DataSpec from composer.core.event import Event from composer.core.precision import Precision @@ -126,6 +128,10 @@ def fsdp_get_optim_state_dict( ) -> dict[str, Any]: """Materializes a given model's optimizer's state_dict. + .. warning:: + This function is deprecated and will be removed in Composer version 0.32. + It is maintained for backwards compatibility with tests. + Args: model (torch.nn.Module): The model that the optimizer corresponds to. optim (torch.optim.Optimizer): The optimizer that you want a state dict for. @@ -140,6 +146,12 @@ def fsdp_get_optim_state_dict( Returns: dict[str, Any]: The state_dict for the given optimizer. """ + warnings.warn( + 'fsdp_get_optim_state_dict is deprecated and will be removed in Composer version 0.32', + DeprecationWarning, + stacklevel=2, + ) + with fsdp_state_dict_type_context(module=model, state_dict_type=state_dict_type): return FSDP.optim_state_dict(model, optim) # type: ignore @@ -200,10 +212,6 @@ def _create_device_mesh( fsdp_config: Optional[FSDPConfig | FSDP2Config], tp_config: Optional[TPConfig], ) -> Optional[DeviceMesh]: - if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'): - # Device mesh has correctness issues before torch 2.3.0 - return None - if fsdp_config is None: return None @@ -986,37 +994,22 @@ def get_model_state_dict(self) -> dict[str, Any]: Returns: dict[str, Any]: The state dict for the model. """ - if version.parse(torch.__version__) >= version.parse('2.4.0') or ( - version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized() - ): - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict - if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: - raise NotImplementedError( - textwrap.dedent( - f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' - f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set ' - 'fsdp_state_dict_type to None, "full", or "sharded".', - ), - ) - - model_state_dict = get_model_state_dict( - model=self.model, - submodules=None, - options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type == 'full', - cpu_offload=self.fsdp_enabled, + if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: + raise NotImplementedError( + textwrap.dedent( + f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' + 'torch version >= 2.4.0. Please set fsdp_state_dict_type to None, "full", or "sharded".', ), ) - else: - if self.fsdp_enabled and self.fsdp_state_dict_type is not None: - with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): - model_state_dict = self.model.state_dict() - else: - model_state_dict = self.model.state_dict() - # If model is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail - if self.is_model_ddp: - torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.') + model_state_dict = get_model_state_dict( + model=self.model, + submodules=None, + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type == 'full', + cpu_offload=self.fsdp_enabled, + ), + ) return model_state_dict @@ -1026,40 +1019,25 @@ def get_optim_state_dict(self) -> dict[str, Any]: Returns: dict[str, Any]: The state dict for the optimizer. """ - if version.parse(torch.__version__) >= version.parse('2.4.0') or ( - version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized() - ): - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict - if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: - raise NotImplementedError( - textwrap.dedent( - f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' - f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set ' - 'fsdp_state_dict_type to None, "full", or "sharded".', - ), - ) - - optimizer = ensure_tuple(self.optimizers)[0] - optim_state_dict = get_optimizer_state_dict( - model=self.model, - optimizers=optimizer, - submodules=None, - options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type == 'full', - cpu_offload=self.fsdp_enabled, + if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: + raise NotImplementedError( + textwrap.dedent( + f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' + 'torch version >= 2.4.0. Please set fsdp_state_dict_type to None, "full", or "sharded".', ), ) - return {type(optimizer).__qualname__: optim_state_dict} - else: - optimizer = ensure_tuple(self.optimizers)[0] - if self.fsdp_enabled and self.fsdp_state_dict_type is not None: - optim_state_dict = { - type(optimizer).__qualname__: - fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type), - } - else: - optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()} - return optim_state_dict + + optimizer = ensure_tuple(self.optimizers)[0] + optim_state_dict = get_optimizer_state_dict( + model=self.model, + optimizers=optimizer, + submodules=None, + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type == 'full', + cpu_offload=self.fsdp_enabled, + ), + ) + return {type(optimizer).__qualname__: optim_state_dict} def state_dict(self) -> dict[str, Any]: """Collect the state dicts of our serializable attributes. @@ -1338,67 +1316,31 @@ def load_model_state( model_on_rank = state_dict['model'] is not None if model_on_rank: - if version.parse(torch.__version__) >= version.parse('2.4.0') or ( - version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized() - ): - from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict - try: - set_model_state_dict( - model=self.model, - model_state_dict=state_dict['model'], - options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type == 'full', - strict=strict, - cpu_offload=self.fsdp_enabled, + try: + set_model_state_dict( + model=self.model, + model_state_dict=state_dict['model'], + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type == 'full', + strict=strict, + cpu_offload=self.fsdp_enabled, + ), + ) + except AttributeError as e: + # Issue: https://github.com/pytorch/pytorch/issues/127351 + if "ShardedTensor' object has no attribute 'placements'" in str(e): + raise RuntimeError( + textwrap.dedent( + 'PyTorch DTensor broke backwards compatibility in older checkpoints ' + 'with ShardedTensor, which is now deprecated. To load old checkpoints, ' + 'either downgrade to PyTorch <2.3.0 or explicitly pass process groups ' + 'in the Trainer constructor via ' + "`parallelism_config = {'fsdp': {'process_group': 'mod1'}}`. We can " + 'provide assistance at https://github.com/mosaicml/composer/issues.', ), - ) - except AttributeError as e: - # Issue: https://github.com/pytorch/pytorch/issues/127351 - if "ShardedTensor' object has no attribute 'placements'" in str(e): - raise RuntimeError( - textwrap.dedent( - 'PyTorch DTensor broke backwards compatibility in older checkpoints ' - 'with ShardedTensor, which is now deprecated. To load old checkpoints, ' - 'either downgrade to PyTorch <2.3.0 or explicitly pass process groups ' - 'in the Trainer constructor via ' - "`parallelism_config = {'fsdp': {'process_group': 'mod1'}}`. We can " - 'provide assistance at https://github.com/mosaicml/composer/issues.', - ), - ) from e - else: - raise e - else: - missing_keys, unexpected_keys = [], [] - try: - # Load model if it exists - if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_monolith_rank0_only: - log.debug( - f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}', - ) - with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): - missing_keys, unexpected_keys = self.model.load_state_dict( - state_dict['model'], - strict=strict, - ) - else: - log.debug(f'Loading model state dict with strict={strict}') - missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict) - except RuntimeError as e: - if 'Missing key(s) in state_dict' in str(e) or 'Unexpected key(s) in state_dict' in str(e): - raise RuntimeError( - textwrap.dedent( - 'Failed to load checkpoint due to missing or unexpected keys in state_dict. ' - 'This is likely due to a change in the model architecture. If this is intentional, ' - 'you can set load_strict_model_weights=False in the Trainer.', - ), - ) from e - else: - raise e - - if len(missing_keys) > 0: - log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if len(unexpected_keys) > 0: - log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") + ) from e + else: + raise e # If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading if self.load_monolith_rank0_only: @@ -1448,52 +1390,17 @@ def load_optim_state(self, state_dict: dict[str, Any], strict: bool = True): continue optim_state_dict = serialized_value[type(optimizer).__qualname__] if serialized_value is not None else None - if version.parse(torch.__version__) >= version.parse('2.4.0') or ( - version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized() - ): - from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict - - # optim_state_dict is `None` on non-zero ranks when loading FSDP monolith - # checkpoint on rank 0 only. However, PyTorch modifies the state_dict (producing - # errors) before discarding the output. Accordingly, we mock the state dict. - # See: https://github.com/pytorch/pytorch/issues/125177 - if version.parse(torch.__version__) < version.parse('2.4.0'): - optim_state_dict = MagicMock() if optim_state_dict is None else optim_state_dict - set_optimizer_state_dict( - model=self.model, - optimizers=optimizer, - optim_state_dict=optim_state_dict, - options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type == 'full', - strict=strict, - cpu_offload=self.fsdp_enabled, - ), - ) - else: - if self.fsdp_enabled: - assert self.fsdp_state_dict_type is not None # pyright - log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}') - # Loading FSDP monolith on rank 0 only requires FSDP.scatter_full_optim_state_dict - # as the context manager does not seem to pass rank0_only=True for the optimizer config - if self.load_monolith_rank0_only: - optim_state_dict = _legacy_optim_state_dict_to_load( - optim_state_dict=optim_state_dict, - model=self.model, - optim=optimizer, - state_dict_type=self.fsdp_state_dict_type, - ) - else: - assert optim_state_dict is not None - with fsdp_state_dict_type_context(module=self.model, state_dict_type=self.fsdp_state_dict_type): - optim_state_dict = FSDP.optim_state_dict_to_load( # type: ignore - optim_state_dict=optim_state_dict, model=self.model, optim=optimizer, - ) - assert optim_state_dict is not None - optimizer.load_state_dict(optim_state_dict) - else: - assert optim_state_dict is not None - log.debug(f'Loading optimizer state dict') - optimizer.load_state_dict(optim_state_dict) + + set_optimizer_state_dict( + model=self.model, + optimizers=optimizer, + optim_state_dict=optim_state_dict, # type: ignore + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type == 'full', + strict=strict, + cpu_offload=self.fsdp_enabled, + ), + ) def load_state_dict( self, diff --git a/composer/trainer/_scaler.py b/composer/trainer/_scaler.py index f84a3ee30d..b67d489781 100644 --- a/composer/trainer/_scaler.py +++ b/composer/trainer/_scaler.py @@ -5,15 +5,9 @@ from typing import Optional, Union import torch -from packaging import version -from torch.cuda.amp.grad_scaler import GradScaler, OptState +from torch.amp.grad_scaler import GradScaler, OptState, _refresh_per_optimizer_state from torch.optim import Optimizer -if version.parse(torch.__version__) >= version.parse('2.3.0'): - from torch.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore -else: - from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore - from composer.utils import dist __all__ = ['ClosureGradScaler'] diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 6bf2f1e868..3c5a7c13a1 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -40,6 +40,7 @@ import torch.utils.data from packaging import version from torch._dynamo import OptimizedModule +from torch.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler @@ -48,11 +49,6 @@ from torch.utils.data import DataLoader, DistributedSampler from torchmetrics import Metric -if version.parse(torch.__version__) >= version.parse('2.3.0'): - from torch.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state # type: ignore -else: - from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state # type: ignore - from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver from composer.core import ( Algorithm, @@ -2169,8 +2165,7 @@ def fit( raise ValueError( 'PyTorch 2.4 breaks (distributed) checkpointing with SGD. ' 'Please use a different optimizer, e.g. composer.optim.DecoupledSGDW, ' - 'instead or downgrade to PyTorch <2.4. See ', - 'https://github.com/pytorch/pytorch/issues/133415 for further information.', + 'instead. See https://github.com/pytorch/pytorch/issues/133415 for further information.', ) if self.state.max_duration is None: diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 085b8dcbb6..e1949e03af 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -20,13 +20,14 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch -from packaging import version from torch.distributed import checkpoint as dist_cp from torch.distributed._tensor import DeviceMesh +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner from torch.distributed.checkpoint.metadata import Metadata from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner from torch.distributed.checkpoint.storage import StorageReader +from torch.distributed.checkpoint.utils import CheckpointException from torch.distributed.distributed_c10d import ProcessGroup from composer.utils import dist, reproducibility @@ -558,64 +559,55 @@ def dist_cp_load( storage_reader: StorageReader, load_planner: Optional[LoadPlanner] = None, ): - if version.parse(torch.__version__) >= version.parse('2.4.0'): - from torch.distributed.checkpoint.utils import CheckpointException - try: + try: + dist_cp.load( + state_dict=state_dict, + storage_reader=storage_reader, + planner=load_planner, + ) + except CheckpointException as e: + if re.search(r'Size mismatch.+torch\.Size\(\[816\]\).+rng\.\d+\.cuda', str(e)) is not None: + # Pytorch 2.6 is strictly enforcing the sizes of the state dict values vs the size + # in the checkpoint. However Pytorch starting in 2.1.0 + # have moved from using RNG of size 816 to 16, therefore causing errors if we + # load a ckpt at or after 2.6.0 if the ckpt was created before 2.1.0 + for idx in range(len(state_dict['rng'])): + state_dict['rng'][idx]['cuda'] = torch.zeros(816, dtype=torch.uint8) + dist_cp.load( state_dict=state_dict, storage_reader=storage_reader, planner=load_planner, ) - except CheckpointException as e: - if re.search(r'Size mismatch.+torch\.Size\(\[816\]\).+rng\.\d+\.cuda', str(e)) is not None: - # Pytorch 2.6 is strictly enforcing the sizes of the state dict values vs the size - # in the checkpoint. However Pytorch starting in 2.1.0 - # have moved from using RNG of size 816 to 16, therefore causing errors if we - # load a ckpt at or after 2.6.0 if the ckpt was created before 2.1.0 - for idx in range(len(state_dict['rng'])): - state_dict['rng'][idx]['cuda'] = torch.zeros(816, dtype=torch.uint8) - - dist_cp.load( - state_dict=state_dict, - storage_reader=storage_reader, - planner=load_planner, - ) - return + return - checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata - if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata: - # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility. - # Torch issue: https://github.com/pytorch/pytorch/issues/133923. - # We override the traverse_state_dict so that the load planner could - # use the old way of flattening the state dict - log.debug('Trying to load checkpointing saved before torch 2.4') + checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata + if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata: + # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility. + # Torch issue: https://github.com/pytorch/pytorch/issues/133923. + # We override the traverse_state_dict so that the load planner could + # use the old way of flattening the state dict + log.debug('Trying to load checkpointing saved before torch 2.4') - import torch.distributed.checkpoint._nested_dict as nested_dict - import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util - from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0 + import torch.distributed.checkpoint._nested_dict as nested_dict + import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util + from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0 - from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse + from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse - nested_dict.traverse_state_dict = backward_compatible_traverse - sharded_tensor_util.traverse_state_dict = backward_compatible_traverse + nested_dict.traverse_state_dict = backward_compatible_traverse + sharded_tensor_util.traverse_state_dict = backward_compatible_traverse - dist_cp.load( - state_dict=state_dict, - storage_reader=storage_reader, - planner=load_planner, - ) - # Revert the override - nested_dict.traverse_state_dict = traverse_2_4_0 - sharded_tensor_util.traverse_state_dict = traverse_2_4_0 - else: - raise e - else: - dist_cp.load_state_dict( - state_dict=state_dict, - storage_reader=storage_reader, - planner=load_planner, - no_dist=(not dist.is_initialized()), - ) + dist_cp.load( + state_dict=state_dict, + storage_reader=storage_reader, + planner=load_planner, + ) + # Revert the override + nested_dict.traverse_state_dict = traverse_2_4_0 + sharded_tensor_util.traverse_state_dict = traverse_2_4_0 + else: + raise e def load_sharded_checkpoint( @@ -630,12 +622,6 @@ def load_sharded_checkpoint( exclude_algorithms: Optional[list[str]] = None, algorithm_passes: Optional[list[AlgorithmPass]] = None, ) -> Union[list[dict], None]: - using_multinode = dist.get_world_size() != dist.get_local_world_size() - if not version.parse(torch.__version__) >= version.parse('2.0.1') and using_multinode: - raise ValueError( - f'Sharded checkpoint loading on >1 node requires torch version >= 2.0.1. You have torch version {torch.__version__}', - ) - if state.fsdp_config is None: raise ValueError('Loading a sharded checkpoint requires passing an FSDP config to Trainer.') @@ -1039,7 +1025,7 @@ def get_save_filename( assert remote_prefix is not None save_dirpath = Path(Path(filename).parent) / Path(remote_prefix) save_dirpath = format_name_with_dist_and_time(str(save_dirpath), state.run_name, state.timestamp) - # New name is now Trainer.save_folder / sharded_ckpt_prefix_dir / __{dist.get_global_rank()}_0.distcp’ + # New name is now Trainer.save_folder / sharded_ckpt_prefix_dir / __{dist.get_global_rank()}_0.distcp' # e.g. path/to/my/checkpoints/ep1-ba2/__1_0.distcp ckpt_filename = _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME return str(Path(save_dirpath) / Path(ckpt_filename)) @@ -1080,13 +1066,6 @@ def _save_checkpoint( # Only rank 0 saves RNG if dist.get_global_rank() > 0: state_dict.pop('rng') - # To load optimizer states with 2.0 <= torch < 2.2.3 , the optimizer state must be at the top - # level of the state dict because the load_sharded_optimizer_state_dict function - # requires a top level state dict key for the optimizer. - # See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271 - # for more info. - if version.parse(torch.__version__) < version.parse('2.2.3'): - state_dict['optimizers'] = state_dict['state'].pop('optimizers') log.debug('State dict created.') dirname = os.path.dirname(save_filename) @@ -1116,31 +1095,15 @@ def _save_checkpoint( expect_file = True if expect_file: - if version.parse(torch.__version__) >= version.parse('2.3.0'): - save_planner = state.fsdp_config.save_planner - if save_planner is None: - if version.parse(torch.__version__) < version.parse('2.4.0'): - # Dedup is only broken on <2.4 - from composer.trainer._patch_pytorch import SavePlannerWithDedupFix - - save_planner = SavePlannerWithDedupFix() - else: - from torch.distributed.checkpoint.default_planner import DefaultSavePlanner - - save_planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True) - dist_cp.save( - state_dict=state_dict, - storage_writer=dist_cp.FileSystemWriter(dirname), - planner=save_planner, - process_group=process_group, - ) - else: - dist_cp.save_state_dict( - state_dict=state_dict, - storage_writer=dist_cp.FileSystemWriter(dirname), - planner=state.fsdp_config.save_planner, - process_group=process_group, - ) + save_planner = state.fsdp_config.save_planner + if save_planner is None: + save_planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True) + dist_cp.save( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(dirname), + planner=save_planner, + process_group=process_group, + ) log.debug('Finished pytorch save state dict') # Save monolith checkpoint elif dist.get_global_rank() == 0: