From 01e83a0ba9451320649263dc15e5ead67a0c2fc8 Mon Sep 17 00:00:00 2001 From: bowenyang008 Date: Fri, 25 Apr 2025 21:03:01 +0000 Subject: [PATCH 1/9] clean up state --- composer/core/state.py | 265 ++++++++++++++--------------------------- 1 file changed, 90 insertions(+), 175 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index db652c8055..b6270e551b 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -11,7 +11,6 @@ from collections import OrderedDict from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Sequence, Union, cast -from unittest.mock import MagicMock import numpy as np import torch @@ -31,10 +30,7 @@ from torch.utils.data import DataLoader, Dataset from torchmetrics import Metric -if version.parse(torch.__version__) >= version.parse('2.3.0'): - from torch.amp.grad_scaler import GradScaler # type: ignore -else: - from torch.cuda.amp.grad_scaler import GradScaler # type: ignore +from torch.amp.grad_scaler import GradScaler from composer.core.data_spec import DataSpec from composer.core.event import Event @@ -72,6 +68,10 @@ @contextmanager def fsdp_state_dict_type_context(module: torch.nn.Module, state_dict_type: str = 'full'): """Context manager for materializing or loading an fsdp module's state dict. + + .. warning:: + This function is deprecated and will be removed in a future release. + It is maintained for backwards compatibility with tests. Args: module (torch.nn.Module): The torch module that you want to call `state_dict()` @@ -87,6 +87,12 @@ def fsdp_state_dict_type_context(module: torch.nn.Module, state_dict_type: str = """ # Torch forgot to put ShardedStateDictConfig in torch/distributed/fsdp/__init__.py, so we # have to import it this way. + warnings.warn( + "fsdp_state_dict_type_context is deprecated and will be removed in a future release", + DeprecationWarning, + stacklevel=2, + ) + from torch.distributed.fsdp.fully_sharded_data_parallel import ShardedStateDictConfig fsdp_state_dict_type = None @@ -125,6 +131,10 @@ def fsdp_get_optim_state_dict( state_dict_type: str = 'full', ) -> dict[str, Any]: """Materializes a given model's optimizer's state_dict. + + .. warning:: + This function is deprecated and will be removed in a future release. + It is maintained for backwards compatibility with tests. Args: model (torch.nn.Module): The model that the optimizer corresponds to. @@ -140,6 +150,12 @@ def fsdp_get_optim_state_dict( Returns: dict[str, Any]: The state_dict for the given optimizer. """ + warnings.warn( + "fsdp_get_optim_state_dict is deprecated and will be removed in a future release", + DeprecationWarning, + stacklevel=2, + ) + with fsdp_state_dict_type_context(module=model, state_dict_type=state_dict_type): return FSDP.optim_state_dict(model, optim) # type: ignore @@ -200,10 +216,6 @@ def _create_device_mesh( fsdp_config: Optional[FSDPConfig | FSDP2Config], tp_config: Optional[TPConfig], ) -> Optional[DeviceMesh]: - if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'): - # Device mesh has correctness issues before torch 2.3.0 - return None - if fsdp_config is None: return None @@ -986,37 +998,23 @@ def get_model_state_dict(self) -> dict[str, Any]: Returns: dict[str, Any]: The state dict for the model. """ - if version.parse(torch.__version__) >= version.parse('2.4.0') or ( - version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized() - ): - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict - if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: - raise NotImplementedError( - textwrap.dedent( - f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' - f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set ' - 'fsdp_state_dict_type to None, "full", or "sharded".', - ), - ) - - model_state_dict = get_model_state_dict( - model=self.model, - submodules=None, - options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type == 'full', - cpu_offload=self.fsdp_enabled, + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: + raise NotImplementedError( + textwrap.dedent( + f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' + 'torch version >= 2.4.0. Please set fsdp_state_dict_type to None, "full", or "sharded".', ), ) - else: - if self.fsdp_enabled and self.fsdp_state_dict_type is not None: - with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): - model_state_dict = self.model.state_dict() - else: - model_state_dict = self.model.state_dict() - # If model is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail - if self.is_model_ddp: - torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.') + model_state_dict = get_model_state_dict( + model=self.model, + submodules=None, + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type == 'full', + cpu_offload=self.fsdp_enabled, + ), + ) return model_state_dict @@ -1026,40 +1024,26 @@ def get_optim_state_dict(self) -> dict[str, Any]: Returns: dict[str, Any]: The state dict for the optimizer. """ - if version.parse(torch.__version__) >= version.parse('2.4.0') or ( - version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized() - ): - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict - if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: - raise NotImplementedError( - textwrap.dedent( - f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' - f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set ' - 'fsdp_state_dict_type to None, "full", or "sharded".', - ), - ) - - optimizer = ensure_tuple(self.optimizers)[0] - optim_state_dict = get_optimizer_state_dict( - model=self.model, - optimizers=optimizer, - submodules=None, - options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type == 'full', - cpu_offload=self.fsdp_enabled, + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict + if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: + raise NotImplementedError( + textwrap.dedent( + f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' + 'torch version >= 2.4.0. Please set fsdp_state_dict_type to None, "full", or "sharded".', ), ) - return {type(optimizer).__qualname__: optim_state_dict} - else: - optimizer = ensure_tuple(self.optimizers)[0] - if self.fsdp_enabled and self.fsdp_state_dict_type is not None: - optim_state_dict = { - type(optimizer).__qualname__: - fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type), - } - else: - optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()} - return optim_state_dict + + optimizer = ensure_tuple(self.optimizers)[0] + optim_state_dict = get_optimizer_state_dict( + model=self.model, + optimizers=optimizer, + submodules=None, + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type == 'full', + cpu_offload=self.fsdp_enabled, + ), + ) + return {type(optimizer).__qualname__: optim_state_dict} def state_dict(self) -> dict[str, Any]: """Collect the state dicts of our serializable attributes. @@ -1338,67 +1322,32 @@ def load_model_state( model_on_rank = state_dict['model'] is not None if model_on_rank: - if version.parse(torch.__version__) >= version.parse('2.4.0') or ( - version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized() - ): - from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict - try: - set_model_state_dict( - model=self.model, - model_state_dict=state_dict['model'], - options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type == 'full', - strict=strict, - cpu_offload=self.fsdp_enabled, + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + try: + set_model_state_dict( + model=self.model, + model_state_dict=state_dict['model'], + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type == 'full', + strict=strict, + cpu_offload=self.fsdp_enabled, + ), + ) + except AttributeError as e: + # Issue: https://github.com/pytorch/pytorch/issues/127351 + if "ShardedTensor' object has no attribute 'placements'" in str(e): + raise RuntimeError( + textwrap.dedent( + 'PyTorch DTensor broke backwards compatibility in older checkpoints ' + 'with ShardedTensor, which is now deprecated. To load old checkpoints, ' + 'either downgrade to PyTorch <2.3.0 or explicitly pass process groups ' + 'in the Trainer constructor via ' + "`parallelism_config = {'fsdp': {'process_group': 'mod1'}}`. We can " + 'provide assistance at https://github.com/mosaicml/composer/issues.', ), - ) - except AttributeError as e: - # Issue: https://github.com/pytorch/pytorch/issues/127351 - if "ShardedTensor' object has no attribute 'placements'" in str(e): - raise RuntimeError( - textwrap.dedent( - 'PyTorch DTensor broke backwards compatibility in older checkpoints ' - 'with ShardedTensor, which is now deprecated. To load old checkpoints, ' - 'either downgrade to PyTorch <2.3.0 or explicitly pass process groups ' - 'in the Trainer constructor via ' - "`parallelism_config = {'fsdp': {'process_group': 'mod1'}}`. We can " - 'provide assistance at https://github.com/mosaicml/composer/issues.', - ), - ) from e - else: - raise e - else: - missing_keys, unexpected_keys = [], [] - try: - # Load model if it exists - if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_monolith_rank0_only: - log.debug( - f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}', - ) - with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): - missing_keys, unexpected_keys = self.model.load_state_dict( - state_dict['model'], - strict=strict, - ) - else: - log.debug(f'Loading model state dict with strict={strict}') - missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict) - except RuntimeError as e: - if 'Missing key(s) in state_dict' in str(e) or 'Unexpected key(s) in state_dict' in str(e): - raise RuntimeError( - textwrap.dedent( - 'Failed to load checkpoint due to missing or unexpected keys in state_dict. ' - 'This is likely due to a change in the model architecture. If this is intentional, ' - 'you can set load_strict_model_weights=False in the Trainer.', - ), - ) from e - else: - raise e - - if len(missing_keys) > 0: - log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if len(unexpected_keys) > 0: - log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") + ) from e + else: + raise e # If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading if self.load_monolith_rank0_only: @@ -1448,52 +1397,18 @@ def load_optim_state(self, state_dict: dict[str, Any], strict: bool = True): continue optim_state_dict = serialized_value[type(optimizer).__qualname__] if serialized_value is not None else None - if version.parse(torch.__version__) >= version.parse('2.4.0') or ( - version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized() - ): - from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict - - # optim_state_dict is `None` on non-zero ranks when loading FSDP monolith - # checkpoint on rank 0 only. However, PyTorch modifies the state_dict (producing - # errors) before discarding the output. Accordingly, we mock the state dict. - # See: https://github.com/pytorch/pytorch/issues/125177 - if version.parse(torch.__version__) < version.parse('2.4.0'): - optim_state_dict = MagicMock() if optim_state_dict is None else optim_state_dict - set_optimizer_state_dict( - model=self.model, - optimizers=optimizer, - optim_state_dict=optim_state_dict, - options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type == 'full', - strict=strict, - cpu_offload=self.fsdp_enabled, - ), - ) - else: - if self.fsdp_enabled: - assert self.fsdp_state_dict_type is not None # pyright - log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}') - # Loading FSDP monolith on rank 0 only requires FSDP.scatter_full_optim_state_dict - # as the context manager does not seem to pass rank0_only=True for the optimizer config - if self.load_monolith_rank0_only: - optim_state_dict = _legacy_optim_state_dict_to_load( - optim_state_dict=optim_state_dict, - model=self.model, - optim=optimizer, - state_dict_type=self.fsdp_state_dict_type, - ) - else: - assert optim_state_dict is not None - with fsdp_state_dict_type_context(module=self.model, state_dict_type=self.fsdp_state_dict_type): - optim_state_dict = FSDP.optim_state_dict_to_load( # type: ignore - optim_state_dict=optim_state_dict, model=self.model, optim=optimizer, - ) - assert optim_state_dict is not None - optimizer.load_state_dict(optim_state_dict) - else: - assert optim_state_dict is not None - log.debug(f'Loading optimizer state dict') - optimizer.load_state_dict(optim_state_dict) + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict + + set_optimizer_state_dict( + model=self.model, + optimizers=optimizer, + optim_state_dict=optim_state_dict, + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type == 'full', + strict=strict, + cpu_offload=self.fsdp_enabled, + ), + ) def load_state_dict( self, From 3cf4e30282b0c229902941d3bc9a46b0c7afd28d Mon Sep 17 00:00:00 2001 From: bowenyang008 Date: Fri, 25 Apr 2025 22:39:01 +0000 Subject: [PATCH 2/9] clean up --- composer/core/state.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index b6270e551b..6525494c30 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -16,7 +16,15 @@ import torch import torch.nn.modules.utils from packaging import version +from torch.amp.grad_scaler import GradScaler from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + set_model_state_dict, + set_optimizer_state_dict, +) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullOptimStateDictConfig, @@ -30,8 +38,6 @@ from torch.utils.data import DataLoader, Dataset from torchmetrics import Metric -from torch.amp.grad_scaler import GradScaler - from composer.core.data_spec import DataSpec from composer.core.event import Event from composer.core.precision import Precision @@ -68,7 +74,7 @@ @contextmanager def fsdp_state_dict_type_context(module: torch.nn.Module, state_dict_type: str = 'full'): """Context manager for materializing or loading an fsdp module's state dict. - + .. warning:: This function is deprecated and will be removed in a future release. It is maintained for backwards compatibility with tests. @@ -88,11 +94,11 @@ def fsdp_state_dict_type_context(module: torch.nn.Module, state_dict_type: str = # Torch forgot to put ShardedStateDictConfig in torch/distributed/fsdp/__init__.py, so we # have to import it this way. warnings.warn( - "fsdp_state_dict_type_context is deprecated and will be removed in a future release", + 'fsdp_state_dict_type_context is deprecated and will be removed in a future release', DeprecationWarning, stacklevel=2, ) - + from torch.distributed.fsdp.fully_sharded_data_parallel import ShardedStateDictConfig fsdp_state_dict_type = None @@ -131,7 +137,7 @@ def fsdp_get_optim_state_dict( state_dict_type: str = 'full', ) -> dict[str, Any]: """Materializes a given model's optimizer's state_dict. - + .. warning:: This function is deprecated and will be removed in a future release. It is maintained for backwards compatibility with tests. @@ -151,11 +157,11 @@ def fsdp_get_optim_state_dict( dict[str, Any]: The state_dict for the given optimizer. """ warnings.warn( - "fsdp_get_optim_state_dict is deprecated and will be removed in a future release", + 'fsdp_get_optim_state_dict is deprecated and will be removed in a future release', DeprecationWarning, stacklevel=2, ) - + with fsdp_state_dict_type_context(module=model, state_dict_type=state_dict_type): return FSDP.optim_state_dict(model, optim) # type: ignore @@ -998,7 +1004,6 @@ def get_model_state_dict(self) -> dict[str, Any]: Returns: dict[str, Any]: The state dict for the model. """ - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: raise NotImplementedError( textwrap.dedent( @@ -1024,7 +1029,6 @@ def get_optim_state_dict(self) -> dict[str, Any]: Returns: dict[str, Any]: The state dict for the optimizer. """ - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: raise NotImplementedError( textwrap.dedent( @@ -1322,7 +1326,6 @@ def load_model_state( model_on_rank = state_dict['model'] is not None if model_on_rank: - from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict try: set_model_state_dict( model=self.model, @@ -1397,12 +1400,11 @@ def load_optim_state(self, state_dict: dict[str, Any], strict: bool = True): continue optim_state_dict = serialized_value[type(optimizer).__qualname__] if serialized_value is not None else None - from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict set_optimizer_state_dict( model=self.model, optimizers=optimizer, - optim_state_dict=optim_state_dict, + optim_state_dict=optim_state_dict, # type: ignore options=StateDictOptions( full_state_dict=self.fsdp_state_dict_type == 'full', strict=strict, From d8f4cacc6d3fec1e43906ad44c3befb53cc3c472 Mon Sep 17 00:00:00 2001 From: bowenyang008 Date: Mon, 28 Apr 2025 20:52:00 +0000 Subject: [PATCH 3/9] update checkpoint --- composer/utils/checkpoint.py | 123 +++++++++-------------------------- 1 file changed, 30 insertions(+), 93 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 085b8dcbb6..c6da987fdb 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -20,13 +20,14 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch -from packaging import version from torch.distributed import checkpoint as dist_cp from torch.distributed._tensor import DeviceMesh +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner from torch.distributed.checkpoint.metadata import Metadata from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner from torch.distributed.checkpoint.storage import StorageReader +from torch.distributed.checkpoint.utils import CheckpointException from torch.distributed.distributed_c10d import ProcessGroup from composer.utils import dist, reproducibility @@ -558,64 +559,28 @@ def dist_cp_load( storage_reader: StorageReader, load_planner: Optional[LoadPlanner] = None, ): - if version.parse(torch.__version__) >= version.parse('2.4.0'): - from torch.distributed.checkpoint.utils import CheckpointException - try: + try: + dist_cp.load( + state_dict=state_dict, + storage_reader=storage_reader, + planner=load_planner, + ) + except CheckpointException as e: + if re.search(r'Size mismatch.+torch\.Size\(\[816\]\).+rng\.\d+\.cuda', str(e)) is not None: + # Pytorch 2.6 is strictly enforcing the sizes of the state dict values vs the size + # in the checkpoint. However Pytorch starting in 2.1.0 + # have moved from using RNG of size 816 to 16, therefore causing errors if we + # load a ckpt at or after 2.6.0 if the ckpt was created before 2.1.0 + for idx in range(len(state_dict['rng'])): + state_dict['rng'][idx]['cuda'] = torch.zeros(816, dtype=torch.uint8) + dist_cp.load( state_dict=state_dict, storage_reader=storage_reader, planner=load_planner, ) - except CheckpointException as e: - if re.search(r'Size mismatch.+torch\.Size\(\[816\]\).+rng\.\d+\.cuda', str(e)) is not None: - # Pytorch 2.6 is strictly enforcing the sizes of the state dict values vs the size - # in the checkpoint. However Pytorch starting in 2.1.0 - # have moved from using RNG of size 816 to 16, therefore causing errors if we - # load a ckpt at or after 2.6.0 if the ckpt was created before 2.1.0 - for idx in range(len(state_dict['rng'])): - state_dict['rng'][idx]['cuda'] = torch.zeros(816, dtype=torch.uint8) - - dist_cp.load( - state_dict=state_dict, - storage_reader=storage_reader, - planner=load_planner, - ) - return - - checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata - if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata: - # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility. - # Torch issue: https://github.com/pytorch/pytorch/issues/133923. - # We override the traverse_state_dict so that the load planner could - # use the old way of flattening the state dict - log.debug('Trying to load checkpointing saved before torch 2.4') - - import torch.distributed.checkpoint._nested_dict as nested_dict - import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util - from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0 - - from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse - - nested_dict.traverse_state_dict = backward_compatible_traverse - sharded_tensor_util.traverse_state_dict = backward_compatible_traverse - - dist_cp.load( - state_dict=state_dict, - storage_reader=storage_reader, - planner=load_planner, - ) - # Revert the override - nested_dict.traverse_state_dict = traverse_2_4_0 - sharded_tensor_util.traverse_state_dict = traverse_2_4_0 - else: - raise e - else: - dist_cp.load_state_dict( - state_dict=state_dict, - storage_reader=storage_reader, - planner=load_planner, - no_dist=(not dist.is_initialized()), - ) + return + raise e def load_sharded_checkpoint( @@ -630,12 +595,6 @@ def load_sharded_checkpoint( exclude_algorithms: Optional[list[str]] = None, algorithm_passes: Optional[list[AlgorithmPass]] = None, ) -> Union[list[dict], None]: - using_multinode = dist.get_world_size() != dist.get_local_world_size() - if not version.parse(torch.__version__) >= version.parse('2.0.1') and using_multinode: - raise ValueError( - f'Sharded checkpoint loading on >1 node requires torch version >= 2.0.1. You have torch version {torch.__version__}', - ) - if state.fsdp_config is None: raise ValueError('Loading a sharded checkpoint requires passing an FSDP config to Trainer.') @@ -1039,7 +998,7 @@ def get_save_filename( assert remote_prefix is not None save_dirpath = Path(Path(filename).parent) / Path(remote_prefix) save_dirpath = format_name_with_dist_and_time(str(save_dirpath), state.run_name, state.timestamp) - # New name is now Trainer.save_folder / sharded_ckpt_prefix_dir / __{dist.get_global_rank()}_0.distcp’ + # New name is now Trainer.save_folder / sharded_ckpt_prefix_dir / __{dist.get_global_rank()}_0.distcp' # e.g. path/to/my/checkpoints/ep1-ba2/__1_0.distcp ckpt_filename = _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME return str(Path(save_dirpath) / Path(ckpt_filename)) @@ -1080,13 +1039,7 @@ def _save_checkpoint( # Only rank 0 saves RNG if dist.get_global_rank() > 0: state_dict.pop('rng') - # To load optimizer states with 2.0 <= torch < 2.2.3 , the optimizer state must be at the top - # level of the state dict because the load_sharded_optimizer_state_dict function - # requires a top level state dict key for the optimizer. - # See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271 - # for more info. - if version.parse(torch.__version__) < version.parse('2.2.3'): - state_dict['optimizers'] = state_dict['state'].pop('optimizers') + # For older versions of PyTorch, we don't need this special handling anymore log.debug('State dict created.') dirname = os.path.dirname(save_filename) @@ -1116,31 +1069,15 @@ def _save_checkpoint( expect_file = True if expect_file: - if version.parse(torch.__version__) >= version.parse('2.3.0'): - save_planner = state.fsdp_config.save_planner - if save_planner is None: - if version.parse(torch.__version__) < version.parse('2.4.0'): - # Dedup is only broken on <2.4 - from composer.trainer._patch_pytorch import SavePlannerWithDedupFix - - save_planner = SavePlannerWithDedupFix() - else: - from torch.distributed.checkpoint.default_planner import DefaultSavePlanner - - save_planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True) - dist_cp.save( - state_dict=state_dict, - storage_writer=dist_cp.FileSystemWriter(dirname), - planner=save_planner, - process_group=process_group, - ) - else: - dist_cp.save_state_dict( - state_dict=state_dict, - storage_writer=dist_cp.FileSystemWriter(dirname), - planner=state.fsdp_config.save_planner, - process_group=process_group, - ) + save_planner = state.fsdp_config.save_planner + if save_planner is None: + save_planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True) + dist_cp.save( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(dirname), + planner=save_planner, + process_group=process_group, + ) log.debug('Finished pytorch save state dict') # Save monolith checkpoint elif dist.get_global_rank() == 0: From 728de381df9f2b672bcab7f6cd218dccc90ec7e4 Mon Sep 17 00:00:00 2001 From: bowenyang008 Date: Mon, 28 Apr 2025 21:03:21 +0000 Subject: [PATCH 4/9] update trainer/scaler --- composer/trainer/_scaler.py | 7 +------ composer/trainer/trainer.py | 9 ++------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/composer/trainer/_scaler.py b/composer/trainer/_scaler.py index f84a3ee30d..98f478588d 100644 --- a/composer/trainer/_scaler.py +++ b/composer/trainer/_scaler.py @@ -5,14 +5,9 @@ from typing import Optional, Union import torch -from packaging import version -from torch.cuda.amp.grad_scaler import GradScaler, OptState from torch.optim import Optimizer +from torch.amp.grad_scaler import _refresh_per_optimizer_state, GradScaler, OptState -if version.parse(torch.__version__) >= version.parse('2.3.0'): - from torch.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore -else: - from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore from composer.utils import dist diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 6bf2f1e868..3c5a7c13a1 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -40,6 +40,7 @@ import torch.utils.data from packaging import version from torch._dynamo import OptimizedModule +from torch.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler @@ -48,11 +49,6 @@ from torch.utils.data import DataLoader, DistributedSampler from torchmetrics import Metric -if version.parse(torch.__version__) >= version.parse('2.3.0'): - from torch.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state # type: ignore -else: - from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state # type: ignore - from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver from composer.core import ( Algorithm, @@ -2169,8 +2165,7 @@ def fit( raise ValueError( 'PyTorch 2.4 breaks (distributed) checkpointing with SGD. ' 'Please use a different optimizer, e.g. composer.optim.DecoupledSGDW, ' - 'instead or downgrade to PyTorch <2.4. See ', - 'https://github.com/pytorch/pytorch/issues/133415 for further information.', + 'instead. See https://github.com/pytorch/pytorch/issues/133415 for further information.', ) if self.state.max_duration is None: From 9ca11a421b251a7a69833e402c2160dec0450a2e Mon Sep 17 00:00:00 2001 From: bowenyang008 Date: Mon, 28 Apr 2025 21:06:50 +0000 Subject: [PATCH 5/9] fix launcher --- composer/cli/launcher.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/composer/cli/launcher.py b/composer/cli/launcher.py index 91110c2add..2c22a8cf6c 100755 --- a/composer/cli/launcher.py +++ b/composer/cli/launcher.py @@ -19,7 +19,6 @@ import psutil import torch -from packaging import version import composer from composer.loggers.mosaicml_logger import ( @@ -321,10 +320,7 @@ def _launch_processes( log.info('Distributed KV store: tcp://%s:%s', master_addr, master_port) nccl_env_variable = { - ( - 'NCCL_ASYNC_ERROR_HANDLING' if version.parse(torch.__version__) < version.parse('2.2.0') else 'TORCH_NCCL_ASYNC_ERROR_HANDLING' - ): - '1', + 'TORCH_NCCL_ASYNC_ERROR_HANDLING': '1', } for local_rank in range(nproc): From a6592b2d306185ddf4ceaeacb0cd1d207a79ecd8 Mon Sep 17 00:00:00 2001 From: bowenyang008 Date: Mon, 28 Apr 2025 21:08:52 +0000 Subject: [PATCH 6/9] format --- composer/trainer/_scaler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/composer/trainer/_scaler.py b/composer/trainer/_scaler.py index 98f478588d..b67d489781 100644 --- a/composer/trainer/_scaler.py +++ b/composer/trainer/_scaler.py @@ -5,9 +5,8 @@ from typing import Optional, Union import torch +from torch.amp.grad_scaler import GradScaler, OptState, _refresh_per_optimizer_state from torch.optim import Optimizer -from torch.amp.grad_scaler import _refresh_per_optimizer_state, GradScaler, OptState - from composer.utils import dist From 3a4f6339f43cdfd222a3a8c2c609dd164e53e35f Mon Sep 17 00:00:00 2001 From: bowenyang008 Date: Mon, 28 Apr 2025 23:08:33 +0000 Subject: [PATCH 7/9] recover support for old ckp --- composer/utils/checkpoint.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index c6da987fdb..1910db55ba 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -580,7 +580,34 @@ def dist_cp_load( planner=load_planner, ) return - raise e + + checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata + if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata: + # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility. + # Torch issue: https://github.com/pytorch/pytorch/issues/133923. + # We override the traverse_state_dict so that the load planner could + # use the old way of flattening the state dict + log.debug('Trying to load checkpointing saved before torch 2.4') + + import torch.distributed.checkpoint._nested_dict as nested_dict + import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util + from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0 + + from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse + + nested_dict.traverse_state_dict = backward_compatible_traverse + sharded_tensor_util.traverse_state_dict = backward_compatible_traverse + + dist_cp.load( + state_dict=state_dict, + storage_reader=storage_reader, + planner=load_planner, + ) + # Revert the override + nested_dict.traverse_state_dict = traverse_2_4_0 + sharded_tensor_util.traverse_state_dict = traverse_2_4_0 + else: + raise e def load_sharded_checkpoint( From 71bcd05500a1e1f270daffc107a1e89457199b64 Mon Sep 17 00:00:00 2001 From: bowenyang008 Date: Mon, 28 Apr 2025 23:42:22 +0000 Subject: [PATCH 8/9] version control --- composer/core/state.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 6525494c30..2e7ac11303 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -75,10 +75,6 @@ def fsdp_state_dict_type_context(module: torch.nn.Module, state_dict_type: str = 'full'): """Context manager for materializing or loading an fsdp module's state dict. - .. warning:: - This function is deprecated and will be removed in a future release. - It is maintained for backwards compatibility with tests. - Args: module (torch.nn.Module): The torch module that you want to call `state_dict()` or `load_state_dict()` on. @@ -93,12 +89,6 @@ def fsdp_state_dict_type_context(module: torch.nn.Module, state_dict_type: str = """ # Torch forgot to put ShardedStateDictConfig in torch/distributed/fsdp/__init__.py, so we # have to import it this way. - warnings.warn( - 'fsdp_state_dict_type_context is deprecated and will be removed in a future release', - DeprecationWarning, - stacklevel=2, - ) - from torch.distributed.fsdp.fully_sharded_data_parallel import ShardedStateDictConfig fsdp_state_dict_type = None @@ -139,7 +129,7 @@ def fsdp_get_optim_state_dict( """Materializes a given model's optimizer's state_dict. .. warning:: - This function is deprecated and will be removed in a future release. + This function is deprecated and will be removed in Composer version 0.32. It is maintained for backwards compatibility with tests. Args: @@ -157,7 +147,7 @@ def fsdp_get_optim_state_dict( dict[str, Any]: The state_dict for the given optimizer. """ warnings.warn( - 'fsdp_get_optim_state_dict is deprecated and will be removed in a future release', + 'fsdp_get_optim_state_dict is deprecated and will be removed in Composer version 0.32', DeprecationWarning, stacklevel=2, ) From 6c9b58b065070c2c57783f19181bed2a04d6f6dc Mon Sep 17 00:00:00 2001 From: bowenyang008 Date: Mon, 28 Apr 2025 23:46:32 +0000 Subject: [PATCH 9/9] out of date comment --- composer/utils/checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 1910db55ba..e1949e03af 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -1066,7 +1066,6 @@ def _save_checkpoint( # Only rank 0 saves RNG if dist.get_global_rank() > 0: state_dict.pop('rng') - # For older versions of PyTorch, we don't need this special handling anymore log.debug('State dict created.') dirname = os.path.dirname(save_filename)