diff --git a/composer/checkpoint/load.py b/composer/checkpoint/load.py
index 43013fdc77..bd4d2f1013 100644
--- a/composer/checkpoint/load.py
+++ b/composer/checkpoint/load.py
@@ -15,8 +15,8 @@
 
 import torch
 import torch.distributed.checkpoint as DCP
-from packaging import version
 from torch import nn
+from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim import Optimizer
 
@@ -273,7 +273,7 @@ def _load_sharded_model_checkpoint(
     model_state_dict = _cast_state_dict_to_precision(model_state_dict, load_options.precision)
     # TODO: raise warning for unknown or missing keys.
     log.debug(f'Loading sharded state dict into model.')
-    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
+    if dist.is_initialized():
         torch_set_model_state_dict(
             model,
             model_state_dict,
@@ -395,7 +395,7 @@ def _load_sharded_optim_checkpoint(
     )
     for param_key, param_state_dict in optim_state_dict['state'].items():
         optim_state_dict['state'][param_key] = _cast_state_dict_to_precision(param_state_dict, load_options.precision)
-    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
+    if dist.is_initialized():
         torch_set_optimizer_state_dict(
             model,
             optim,
@@ -522,7 +522,7 @@ def torch_set_model_state_dict(
         cpu_offload (bool): Whether to offload the state dict to CPU before setting.
         sharded_state_dict (bool): Whether the state dict is sharded or not.
     """
-    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
+    if dist.is_initialized():
         from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
         try:
             set_model_state_dict(
@@ -569,8 +569,7 @@ def torch_set_optimizer_state_dict(
         cpu_offload (bool): Whether to offload the state dict to CPU before setting.
         sharded_state_dict (bool): Whether the state dict is sharded or not.
     """
-    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
-        from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
+    if dist.is_initialized():
         set_optimizer_state_dict(
             model=model,
             optimizers=optim,
@@ -626,10 +625,7 @@ def download_and_load_sharded_state_dict(
             raise
 
     log.debug(f'Loading sharded state dict from {load_path} into memory.')
-    if version.parse(torch.__version__) < version.parse('2.2.0'):
-        DCP.load_state_dict(state_dict=state_dict, storage_reader=storage_reader, planner=load_planner)
-    else:
-        DCP.load(state_dict=state_dict, storage_reader=storage_reader, planner=load_planner)
+    DCP.load(state_dict=state_dict, storage_reader=storage_reader, planner=load_planner)
     return state_dict
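
[Reviewer note] With the torch<2.2 fallback removed, sharded checkpoints always go through `DCP.save`/`DCP.load`. A minimal sketch of that round trip, assuming torch>=2.3; the toy `nn.Linear`, the `ckpt_dir` path, and single-process usage are illustrative stand-ins, not code from this PR:

```python
import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict

model = torch.nn.Linear(8, 8)   # stand-in for a (possibly FSDP-wrapped) Composer model
ckpt_dir = '/tmp/sharded_ckpt'  # illustrative path

# Save: DCP.save is now the only write path (DCP.save_state_dict was the pre-2.2 API).
DCP.save(
    state_dict={'model': get_model_state_dict(model)},
    storage_writer=DCP.FileSystemWriter(ckpt_dir),
)

# Load: DCP.load fills the passed-in state_dict in place; push it back into the module afterwards.
loaded = {'model': get_model_state_dict(model)}
DCP.load(state_dict=loaded, storage_reader=DCP.FileSystemReader(ckpt_dir))
set_model_state_dict(model, loaded['model'])
```
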
diff --git a/composer/checkpoint/save.py b/composer/checkpoint/save.py
index fd5dd4629d..6a536cd2e4 100644
--- a/composer/checkpoint/save.py
+++ b/composer/checkpoint/save.py
@@ -15,7 +15,6 @@
 
 import torch
 import torch.distributed.checkpoint as DCP
-from packaging import version
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor
 
@@ -358,13 +357,7 @@ def _save_sharded_state_dict_to_disk(
         f'Starting saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
     )
 
-    # For 2.3.0 and above you can use checkpoint_id, but this version works the best for all versions
-    # of torch (and makes pyright happier) that we support, so we use it for now.
-    if version.parse(torch.__version__) < version.parse('2.2.0'):
-        DCP.save_state_dict(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
-    else:
-        DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
-
+    DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
     log.debug(
         f'Finished saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
     )
diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py
index 68910c070a..b5a4afc60c 100644
--- a/composer/checkpoint/state_dict.py
+++ b/composer/checkpoint/state_dict.py
@@ -16,7 +16,6 @@
     from composer.core.evaluator import Evaluator
 
 import torch
-from packaging import version
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel
@@ -65,7 +64,7 @@ def get_model_state_dict(
     cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict)
 
     log.debug('Extracting model state dict')
-    if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized():
+    if dist.is_initialized():
         from torch.distributed.checkpoint import state_dict as DCPSD  # Distributed Checkpoint State Dict
         from torch.distributed.checkpoint.state_dict import StateDictOptions
 
@@ -337,7 +336,7 @@ def get_optim_state_dict(
     cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict)
 
     log.debug('Extracting optim state dict')
-    if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized():
+    if dist.is_initialized():
         from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
         log.debug('Calling torch get_optimizer_state_dict...')
         optim_state_dict: dict[str, Any] = get_optimizer_state_dict(
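
[Reviewer note] `state_dict.py` now calls torch's `get_model_state_dict`/`get_optimizer_state_dict` unconditionally whenever a process group is initialized. A rough usage sketch of those APIs with `StateDictOptions`; the toy model and optimizer stand in for Composer's wrapped `state.model` and `state.optimizers[0]`, and the option values shown are just one plausible configuration:

```python
import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
)

model = torch.nn.Linear(4, 4)                  # illustrative model
optim = torch.optim.AdamW(model.parameters())  # illustrative optimizer

# full_state_dict=False keeps per-rank shards under FSDP; cpu_offload moves tensors to CPU,
# mirroring the sharded_state_dict/cpu_offload flags handled in get_model_state_dict above.
options = StateDictOptions(full_state_dict=False, cpu_offload=True)
model_sd = get_model_state_dict(model, options=options)
optim_sd = get_optimizer_state_dict(model, optim, options=options)
```
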
diff --git a/composer/core/state.py b/composer/core/state.py
index 198aeac89e..f396848805 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -15,7 +15,6 @@
 import numpy as np
 import torch
 import torch.nn.modules.utils
-from packaging import version
 from torch.amp.grad_scaler import GradScaler
 from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.checkpoint.state_dict import (
@@ -25,6 +24,7 @@
     set_model_state_dict,
     set_optimizer_state_dict,
 )
+from torch.distributed.fsdp import FSDPModule
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullOptimStateDictConfig,
@@ -601,8 +601,6 @@ def _validate_parallelism_configs(self):
         # Validate TP config
         if self.tp_config is not None:
             warnings.warn('Tensor parallelism (TP) is experimental and may change in future versions.', FutureWarning)
-            if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
-                raise ValueError('Tensor parallelism (TP) requires torch>=2.3.0.')
             if self.fsdp_config is None:
                 raise ValueError(
                     'Tensor parallelism (TP) currently requires FSDP to be enabled. '
@@ -930,13 +928,8 @@ def fsdp_enabled(self):
         """Indicates if FSDP is enabled."""
         for module in self.model.modules():
             # FSDP is FSDP1, FSDPModule is FSDP2
-            if isinstance(module, FSDP):
+            if isinstance(module, (FSDP, FSDPModule)):
                 return True
-            # TODO remove this once we deprecate torch 2.5
-            if version.parse(torch.__version__) >= version.parse('2.6.0'):
-                from torch.distributed.fsdp._fully_shard import FSDPModule
-                if isinstance(module, FSDPModule):
-                    return True
         return False
 
     @property
diff --git a/composer/devices/device_mps.py b/composer/devices/device_mps.py
index 515386ca2d..106a963995 100644
--- a/composer/devices/device_mps.py
+++ b/composer/devices/device_mps.py
@@ -10,7 +10,6 @@
 import torch
 import torch.cuda.amp
 import torch.utils.data
-from packaging import version
 
 from composer.devices.device import Device
 
@@ -28,8 +27,6 @@ class DeviceMPS(Device):
     name = 'mps'
 
     def __init__(self) -> None:
-        if version.parse(torch.__version__) < version.parse('1.12.0'):
-            raise RuntimeError('Support for MPS device requires torch >= 1.12.')
         if not torch.backends.mps.is_available():  # type: ignore (version guarded)
             raise RuntimeError('MPS requires MAC OSX >= 12.3')
         if not torch.backends.mps.is_built():  # type: ignore (version guarded)
diff --git a/composer/distributed/__init__.py b/composer/distributed/__init__.py
index 2d6a81563b..ba2a71eacf 100644
--- a/composer/distributed/__init__.py
+++ b/composer/distributed/__init__.py
@@ -11,6 +11,7 @@
     prepare_fsdp_module,
     prepare_tp_module,
 )
+from composer.distributed.prepare_distributed import parallelize_composer_model
 
 __all__ = [
     'DDPSyncStrategy',
@@ -18,4 +19,5 @@
     'prepare_ddp_module',
     'prepare_fsdp_module',
     'prepare_tp_module',
+    'parallelize_composer_model',
 ]
diff --git a/composer/distributed/dist_strategy.py b/composer/distributed/dist_strategy.py
index bbfa46fbc0..ae621491f7 100644
--- a/composer/distributed/dist_strategy.py
+++ b/composer/distributed/dist_strategy.py
@@ -9,7 +9,6 @@
 from typing import Any, Callable, ContextManager, Iterator, Optional, Sequence, Union, cast
 
 import torch
-from packaging import version
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     CheckpointImpl,
     apply_activation_checkpointing,
@@ -338,9 +337,7 @@ def sync_hook(*args):
         sharding_strategy = SHARDING_MAP[sharding_map_key]
 
         kwargs = {}
-        if version.parse(
-            torch.__version__.split('.dev')[0],
-        ) >= version.parse('2.2.0') and fsdp_config.device_mesh is not None:
+        if fsdp_config.device_mesh is not None:
             if fsdp_config.process_group is not None:
                 warnings.warn(
                     'process_group and device_mesh are set for FSDP, so ignoring device_mesh. 
Please set process_group to None.', diff --git a/composer/trainer/_patch_pytorch.py b/composer/trainer/_patch_pytorch.py index 78221001aa..ab79896587 100644 --- a/composer/trainer/_patch_pytorch.py +++ b/composer/trainer/_patch_pytorch.py @@ -11,17 +11,10 @@ """PyTorch, especially PyTorch Distributed, monkeypatches.""" import logging -import functools -import contextlib -from dataclasses import asdict -from itertools import chain -from typing import Any, Callable, Iterable, Generator, Optional, Union, cast, no_type_check - +from typing import no_type_check import torch -import torch.nn as nn from packaging import version -from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy from composer.utils import dist @@ -29,1042 +22,23 @@ def patch_unshard_for_automicrobatching(auto_microbatch_size_found=False): """Monkey patches sync hook into unshard when searching during automicrobatching.""" - if version.parse(torch.__version__) >= version.parse('2.3.1'): - from torch.distributed.fsdp._flat_param import FlatParamHandle - if auto_microbatch_size_found: - global original_unshard - FlatParamHandle.unshard = (original_unshard) - else: - FlatParamHandle.unshard = (unshard_with_sync) + from torch.distributed.fsdp._flat_param import FlatParamHandle + if auto_microbatch_size_found: + global original_unshard + FlatParamHandle.unshard = (original_unshard) + else: + FlatParamHandle.unshard = (unshard_with_sync) def patch_pytorch(): """Monkey patches pytorch functions based on pytorch version.""" - if version.parse(torch.__version__) < version.parse('2.2.1'): - # Monkey patch for torch < 2.2.1 ie torch == 2.2.0 - - # Allow 2D HSDP - from torch.distributed.fsdp import _runtime_utils - _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None - - elif version.parse(torch.__version__) < version.parse('2.2.3'): - # Monkey patch for torch < 2.2.3 ie torch == 2.2.1/2.2.2 currently - - # Fix memory leak for FSDP.optim_state_dict_to_load - # https://github.com/pytorch/pytorch/issues/116553 - from torch.distributed.fsdp import _optim_utils - - _optim_utils._shard_orig_param_state = _shard_orig_param_state - - elif version.parse(torch.__version__) < version.parse('2.3.1'): - # Monkey patch for torch < 2.3.1 ie torch == 2.3.0 - - # Monkeypatch _flat_param.py to fix 2D with SHARD_GRAD_OP - # Issue: https://github.com/pytorch/pytorch/issues/123272 - from torch.distributed.fsdp import _flat_param - - _flat_param._same_storage = _same_storage - - # Monkeypatch state_dict to get FQNs correctly. 
- # Issue: https://github.com/pytorch/pytorch/pull/124698 - from torch.distributed.checkpoint import state_dict - - state_dict.set_model_state_dict = set_model_state_dict - state_dict.set_optimizer_state_dict = set_optimizer_state_dict - # Issue: https://github.com/pytorch/pytorch/issues/122946 - # - PR: https://github.com/pytorch/pytorch/pull/125336 - state_dict._get_fqns = _get_fqns - state_dict._verify_options = _verify_options - state_dict._get_model_state_dict = _get_model_state_dict - state_dict._load_model_state_dict = _load_model_state_dict - - # Monkeypatch for ND child submeshes - # PR: https://github.com/pytorch/pytorch/pull/119752 - from torch.distributed.device_mesh import DeviceMesh, _MeshEnv - - _MeshEnv.create_child_mesh = create_child_mesh - DeviceMesh.__getitem__ = device_mesh__getitem__ - DeviceMesh.__init__ = device_mesh__init__ - - elif version.parse(torch.__version__) < version.parse('2.3.2'): - # Monkey patch for torch < 2.3.2 ie torch == 2.3.1 - - # Issue: https://github.com/pytorch/pytorch/issues/122946 - # - PR: https://github.com/pytorch/pytorch/pull/125336 - from torch.distributed.checkpoint import state_dict - - state_dict._verify_options = _verify_options - state_dict._get_model_state_dict = _get_model_state_dict - state_dict._load_model_state_dict = _load_model_state_dict - - # Monkeypatch for ND child submeshes - # PR: https://github.com/pytorch/pytorch/pull/119752 - from torch.distributed.device_mesh import DeviceMesh, _MeshEnv - - _MeshEnv.create_child_mesh = create_child_mesh - DeviceMesh.__getitem__ = device_mesh__getitem__ - - elif version.parse(torch.__version__) < version.parse('2.4.1'): - # Monkey patch for torch < 2.4.1 ie torch == 2.4.0 - - # No monkeypatches besides unshard (below)! - pass - - elif version.parse(torch.__version__) < version.parse('2.5.1'): - # Monkey patch for torch < 2.5.1 ie torch == 2.5.0 - - # No monkeypatches besides unshard (below)! - pass - - elif version.parse(torch.__version__) < version.parse('2.6.1'): - # Monkey patch for torch < 2.6.1 ie torch == 2.6.0 - - # No monkeypatches besides unshard (below)! - pass - - elif version.parse(torch.__version__) < version.parse('2.7.1'): + if version.parse(torch.__version__) < version.parse('2.7.1'): # Monkey patch for torch < 2.7.1 ie torch == 2.7.0 # No monkeypatches besides unshard (below)! pass -if version.parse(torch.__version__) >= version.parse('2.2.1') and version.parse( - torch.__version__,) < version.parse('2.2.3'): - - from torch.distributed.fsdp._optim_utils import FSDPParamInfo - from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict - - @no_type_check - def _shard_orig_param_state( - fsdp_param_info: FSDPParamInfo, - fqn: str, - optim_state: dict[str, Any], - ) -> dict[str, Any]: - if not optim_state: - return {} - fsdp_state = fsdp_param_info.state - flat_param = fsdp_param_info.handle.flat_param - param_idx = fsdp_param_info.param_indices[fqn] - shard_param_info = flat_param._shard_param_infos[param_idx] # type: ignore[attr-defined] - optim_state = _gather_state_dict( - optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device, - ) - if not shard_param_info.in_shard: - return {} - # Flatten and shard the state. 
- new_optim_state: dict[str, Any] = {} - intra_param_start_idx = shard_param_info.intra_param_start_idx - intra_param_end_idx = shard_param_info.intra_param_end_idx - for state_name, value in optim_state.items(): - if ( - torch.is_tensor(value) - and value.dim() > 0 - and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD - ): - # This clone() is the patch to fix the OOM - # https://github.com/pytorch/pytorch/pull/117261 - value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone() # type: ignore[operator] - new_optim_state[state_name] = value - return new_optim_state - - -if version.parse(torch.__version__) >= version.parse('2.3.0') and version.parse( - torch.__version__, -) < version.parse('2.3.2'): - from torch.distributed._tensor import DTensor - - @no_type_check - def _same_storage(a, b): - if isinstance(a, DTensor): - a = a._local_tensor - if isinstance(b, DTensor): - b = b._local_tensor - return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr() - - from torch.distributed.checkpoint.state_dict import (_unflatten_model_state_dict, - gc_context, - _load_optim_state_dict, - _state_dict_fn, - _offload_state_dict_to_cpu, - _verify_state_dict, - StateDictOptions, _StateDictInfo, - FLAT_PARAM, FQNS_T) - from torch.distributed._state_dict_utils import _gather_state_dict - from torch.nn.modules.module import _IncompatibleKeys - from torch.distributed.fsdp import ( - FullOptimStateDictConfig, - FullStateDictConfig, - FullyShardedDataParallel as FSDP, - OptimStateDictConfig, - ShardedOptimStateDictConfig, - ShardedStateDictConfig, - StateDictConfig, - StateDictType, - ) - - from torch.distributed._shard.sharded_tensor import ShardedTensor - - PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str] - ValueType = Union[ - PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, 'ValueType'], - ] - DictValueType = dict[str, ValueType] - ListDictValueType = list[DictValueType] - OptimizerStateType = dict[str, Union[DictValueType, ListDictValueType]] - - class _EXTRA_STATE: - pass - - def _iterate_valid_model_state(model): - visited_modules: set[nn.Module] = set() - - def recurse(module: nn.Module, curr_fqn: str) -> Generator: - visited_modules.add(module) - - curr_fqn = f'{curr_fqn}.' 
if curr_fqn else '' - for name, submodule in module.named_children(): - if submodule in visited_modules: - continue - new_fqn = f'{curr_fqn}{name}' - yield from recurse(submodule, new_fqn) - - for name, obj in chain( - module.named_buffers(recurse=False), module.named_parameters(recurse=False), - ): - new_fqn = f'{curr_fqn}{name}' - yield new_fqn, obj - - if ( - getattr(module.__class__, 'get_extra_state', nn.Module.get_extra_state) - != nn.Module.get_extra_state - ): - new_fqn = f'{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}' - yield new_fqn, _EXTRA_STATE() - - yield from recurse(model, '') - - def _verify_options( - model: nn.Module, - optims: tuple[torch.optim.Optimizer, ...], - optim_only: bool, - *, - submodules: Optional[set[nn.Module]] = None, - options: Optional[StateDictOptions] = None, - ) -> _StateDictInfo: - """Verify the model and options passed by the user and generates _StateDictInfo.""" - if optim_only and not optims: - raise RuntimeError( - 'Optimizers are not passed in but optim_only is set to True.', - ) - - options = options or StateDictOptions() - - fqn_param_mapping: dict[ - Union[str, torch.Tensor], Union[set[str], torch.Tensor], - ] = {} - for name, param in chain(model.named_parameters(), model.named_buffers()): - fqns = _get_fqns(model, name) - fqn_param_mapping[param] = fqns - for fqn in fqns: - fqn_param_mapping[fqn] = param - - all_fqns = set() - for name, _ in _iterate_valid_model_state(model): - fqns = _get_fqns(model, name) - for fqn in fqns: - all_fqns.add(fqn) - - submodule_prefixes: set[str] = set() - if submodules: - submodules = set(submodules) - for name, module in model.named_modules(): - if module not in submodules: - continue - fqns = _get_fqns(model, name) - assert len(fqns) == 1, 'Submodule FQN should only have 1 instance' - submodule_prefixes.update(f'{fqn}.' for fqn in fqns) - - fsdp_modules = FSDP.fsdp_modules(model) - state_dict_config: StateDictConfig - optim_state_dict_config: OptimStateDictConfig - fsdp_context: Callable - if fsdp_modules: - # FSDP API only work if at least one FSDP instance exists. 
- if options.full_state_dict: - state_dict_config = FullStateDictConfig( - offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload, - ) - optim_state_dict_config = FullOptimStateDictConfig( - offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload, - ) - state_dict_type = StateDictType.FULL_STATE_DICT - else: - state_dict_config = ShardedStateDictConfig( - offload_to_cpu=options.cpu_offload, - ) - optim_state_dict_config = ShardedOptimStateDictConfig( - offload_to_cpu=options.cpu_offload, - ) - state_dict_type = StateDictType.SHARDED_STATE_DICT - - fsdp_context = functools.partial( - FSDP.state_dict_type, - module=model, - state_dict_type=state_dict_type, - state_dict_config=state_dict_config, - optim_state_dict_config=optim_state_dict_config, - ) - else: - fsdp_context = contextlib.nullcontext - - return _StateDictInfo( - **asdict(options), - fqn_param_mapping=fqn_param_mapping, - all_fqns=all_fqns, - submodule_prefixes=submodule_prefixes, - fsdp_context=fsdp_context, - fsdp_modules=cast(list[nn.Module], fsdp_modules), - handle_model=not optim_only, - handle_optim=(len(optims) > 0), - ) - - - def _get_model_state_dict( - model: nn.Module, info: _StateDictInfo, - ) -> dict[str, ValueType]: - if not info.handle_model: - return {} - - with info.fsdp_context(): - state_dict = _state_dict_fn(model, 'state_dict')() - - for key in list(state_dict.keys()): - fqns = _get_fqns(model, key) - assert len(fqns) == 1 - fqn = next(iter(fqns)) - if fqn != key: - # As we only support FSDP, DDP, and TP, the only cases are - # wrapper-based DDP and compiler. Verify if the assumption - # is correct. - def verify(key, fqn) -> bool: - if len(fqn) >= len(key): - return False - fqn_split = fqn.split('.') - key_split = key.split('.') - fqn_idx = 0 - for key_idx, key_name in enumerate(key_split): - if key_name == fqn_split[fqn_idx]: - fqn_idx += 1 - if fqn_idx == len(fqn_split): - return key_idx == len(key_split) - 1 - elif key_name in ('module', '_orig_mod'): - continue - else: - return False - return True - - if not verify(key, fqn): - raise RuntimeError(f'An unexpected key, {key}, exists. FQN is {fqn}') - state_dict[fqn] = state_dict.pop(key) - - if info.submodule_prefixes: - new_state_dict: dict[str, ValueType] = {} - # TODO: make this faster. 
- for fqn in state_dict.keys(): - for prefix in info.submodule_prefixes: - if not fqn.startswith(prefix): - continue - if info.keep_submodule_prefixes: - new_state_dict[fqn] = state_dict[fqn] - else: - new_fqn = fqn[len(prefix) :] - new_state_dict[new_fqn] = state_dict[fqn] - state_dict = new_state_dict - - if info.ignore_frozen_params: - for key, param in model.named_parameters(): - if param.requires_grad: - continue - fqns = _get_fqns(model, key) - for fqn in fqns: - state_dict.pop(fqn) - - for key, p in list(state_dict.items()): - if torch.is_tensor(p) and p.is_meta: - state_dict.pop(key) - - if info.full_state_dict: - ranks_only = () if not info.cpu_offload else (0,) - return _gather_state_dict( - state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only, - ) - elif info.cpu_offload: - return _offload_state_dict_to_cpu(state_dict) - else: - return state_dict - - def _load_model_state_dict( - model: nn.Module, - state_dict: dict[str, ValueType], - info: _StateDictInfo, - ) -> _IncompatibleKeys: - if not info.handle_model or not state_dict: - return _IncompatibleKeys({}, {}) - - for key, _ in _iterate_valid_model_state(model): - fqns = _get_fqns(model, key) - fqns_with_prefix = _get_fqns( - model, key, skip_ddp_prefix=False, skip_compiler_prefix=False, - ) - for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix): - if fqn != fqn_with_prefix: - state_dict[fqn_with_prefix] = state_dict.pop(fqn) - - with info.fsdp_context(): - return cast( - _IncompatibleKeys, - _state_dict_fn(model, 'load_state_dict')( - state_dict=state_dict, strict=info.strict, - ), - ) - - - @no_type_check - @functools.lru_cache(maxsize=None) - def _get_fqns( - model: nn.Module, - name: str, - skip_ddp_prefix: bool = True, - skip_compiler_prefix: bool = True, - ) -> FQNS_T: - """Used to convert the name of a parameter to the FQNs. - - For FSDP without `use_orig_params`, the name of FlatParameter can be mapped to - multiple original parameters. As a result, the return type of this function - is `set[str]`. - - Args: - module (nn.Module): the root model. - name (str): the name - skip_ddp_prefix (bool): whether to skip DDP's `module` prefix - - Returns: - The canonical FQNs based on the model traversal. - """ - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import _CHECKPOINT_PREFIX - from torch.nn.parallel import DistributedDataParallel as DDP - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE - - # Remove the checkpoint prefix, if it exists. - name = name.replace(_CHECKPOINT_PREFIX, '') - if '.' not in name: - return {name} - - obj_names = name.split('.') - fqn_obj_names = [] - curr_obj = model - for i, curr_obj_name in enumerate(obj_names): - if isinstance(curr_obj, DDP): - assert curr_obj_name == 'module' - curr_obj = curr_obj.module - if not skip_ddp_prefix: - fqn_obj_names.append(curr_obj_name) - elif isinstance(curr_obj, FSDP): - if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM: - prefix = '.'.join(fqn_obj_names) - flat_param = getattr(curr_obj, FLAT_PARAM) - if prefix: - prefix = f'{prefix}.' 
- return {f'{prefix}{fqn}' for fqn in flat_param._fqns} - curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) - if curr_obj_name != FSDP_WRAPPED_MODULE: - fqn_obj_names.append(curr_obj_name) - curr_obj = getattr(curr_obj, curr_obj_name) - elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule): - assert curr_obj_name == '_orig_mod' - curr_obj = curr_obj._orig_mod - if not skip_compiler_prefix: - fqn_obj_names.append(curr_obj_name) - else: - # This part is monkey-patched from https://github.com/pytorch/pytorch/pull/125336 - fqn_obj_names.append(curr_obj_name) - if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX: - if i != len(obj_names) - 1: - raise RuntimeError('Expect `_extra_state` to be the last obj name') - else: - curr_obj = getattr(curr_obj, curr_obj_name) - - return {'.'.join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, '')} - - def set_model_state_dict( - model: nn.Module, - model_state_dict, - *, - options = None, - ): - """Load the model state_dict. - - The counterpart of ``get_model_state_dict`` to set the state_dict to the - model. See ``set_state_dict`` for the detail usage. - - Args: - model (nn.Module): the nn.Module to the model. - model_state_dict: (dict[str, ValueType]): - the model state_dict to load. If the key of the ``model_state_dict`` - is nn.Module, the key is a submodule of ``model`` and the value should - be the state_dict of the submodule. When loading the state_dict, - the prefix of the submodule will be append to the state_dict. - options (StateDictOptions): the options to control how - model state_dict and optimizer state_dict should be loaded. See - `StateDictOptions` for the details. - - Returns: - ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: - * **missing_keys** is a list of str containing the missing keys - * **unexpected_keys** is a list of str containing the unexpected keys - - :type model_state_dict: typing.dict[str, ValueType] - """ - from torch.distributed.fsdp._runtime_utils import _lazy_init - for module in model.modules(): - if isinstance(module, FullyShardedDataParallel): - _lazy_init(module, module) - model_state_dict = _unflatten_model_state_dict( - model, model_state_dict, - ) - with gc_context(): - info = _verify_options(model, tuple(), optim_only=False, options=options) - - _verify_state_dict(model_state_dict, {}, info) - return _load_model_state_dict(model, model_state_dict, info) - - def set_optimizer_state_dict( - model: nn.Module, - optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], - *, - optim_state_dict, - options = None, - ) -> None: - """Load the optimizers state_dict. - - The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the - optimizers. See ``set_state_dict`` for the detail usage. - - Args: - model (nn.Module): the nn.Module to the model. - optimizers (Union[Optimizer, Iterable[Optimizer]]): - The optimizers that are used to optimize ``model``. - optim_state_dict: OptimizerStateType: - the optimizer state_dict to load. - options (StateDictOptions): the options to control how - model state_dict and optimizer state_dict should be loaded. See - `StateDictOptions` for the details. 
- - Returns: - None - - :type optim_state_dict: typing.OptimizerStateType - """ - from torch.distributed.fsdp._runtime_utils import _lazy_init - for module in model.modules(): - if isinstance(module, FullyShardedDataParallel): - _lazy_init(module, module) - with gc_context(): - optimizers = ( - (optimizers,) - if isinstance(optimizers, torch.optim.Optimizer) - else tuple(optimizers) - ) - info = _verify_options(model, optimizers, optim_only=True, options=options) - - _verify_state_dict({}, optim_state_dict, info) - _load_optim_state_dict(model, optimizers, optim_state_dict, info) - - - # torch2.3 patch to fix https://github.com/pytorch/pytorch/issues/125740 - from torch.distributed.checkpoint.default_planner import ( - create_default_global_save_plan, - DefaultSavePlanner, - _validate_global_plan, - ) - import dataclasses - from collections import defaultdict, ChainMap - - from torch.distributed.checkpoint.planner import SavePlan, WriteItem - from torch.distributed.checkpoint.metadata import MetadataIndex, Metadata - - def dedup_save_plans(all_plans: list[SavePlan]) -> list[SavePlan]: # noqa: D103 - write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set) - write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {} - for plan_idx, plan in enumerate(all_plans): - for write_item in plan.items: - # map each write item to its plan - write_item_to_plan_indices[write_item.index].add(plan_idx) - write_item_idx_to_write_item[write_item.index] = write_item - - # put item in the plan with the smallest size and remove it from the other plan_indices - to_remove: list[set] = [set() for _ in range(len(all_plans))] - plan_to_size = [0] * len(all_plans) - for write_item_idx, plan_indices in write_item_to_plan_indices.items(): - # this line is the fix, to keep the duplicated tensors on the same rank - select_plan_idx = min(plan_indices, key=lambda plan_idx: plan_idx) - - write_item = write_item_idx_to_write_item[write_item_idx] - # essentially ignores the storage size of anything that is not a tensor, since - # we don't know how much storage they represent - plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1 - - plan_indices.remove(select_plan_idx) - for plan_idx in plan_indices: - to_remove[plan_idx].add(write_item_idx) - - for plan_idx, remove_set in enumerate(to_remove): - new_items = [ - write_item - for write_item in all_plans[plan_idx].items - if write_item.index not in remove_set - ] - all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) - - return all_plans - - class SavePlannerWithDedupFix(DefaultSavePlanner): # noqa: D101 - def create_global_plan( - self, all_plans: list[SavePlan], - ) -> tuple[list[SavePlan], Metadata]: - all_plans = dedup_save_plans(all_plans) - - global_plan, metadata = create_default_global_save_plan(all_plans) - - if self.flatten_state_dict: - # | does not work for Python 3.8 or older version. 
- # merged_mappings = reduce( - # lambda x, y: x | y, (p.planner_data for p in global_plan) - # ) - planner_data_dict = [p.planner_data for p in global_plan] - merged_mappings = dict(ChainMap(*planner_data_dict)) - metadata = dataclasses.replace(metadata, planner_data=merged_mappings) - - if not _validate_global_plan(global_plan, metadata): - raise ValueError('Failed to validate global plan') - - self.global_plan = global_plan - self.metadata = metadata - - return self.global_plan, self.metadata - - # DeviceMesh monkeypatch slightly changes in PyTorch 2.3.1 - if version.parse(torch.__version__) < version.parse('2.3.1'): - from torch.utils._typing_utils import not_none - from torch.distributed.device_mesh import DeviceMesh - - def create_child_mesh( - self, - device_mesh, - mesh_dim_names: tuple[str], - ): - """Monkeypatch create_child_mesh to nightly version.""" - # swap the current dim to the last dim then reshape to flatten out other - # dims, so we can just extract the list of ranks which contains cur_rank. - mesh_dims = [ - not_none(device_mesh.mesh_dim_names).index(mesh_dim_name) - for mesh_dim_name in mesh_dim_names - ] - cur_rank = device_mesh.get_rank() - mesh = device_mesh.mesh - all_mesh_dims = list(range(mesh.ndim)) - for mesh_dim in mesh_dims: - # remove not pop b/c we want the value of the ind removed not it's position in the list - # because this list dynamically changes. - all_mesh_dims.remove(mesh_dim) - - mesh_sizes = [device_mesh.mesh.size(mesh_dim) for mesh_dim in mesh_dims] - - pg_ranks_by_dim = device_mesh.mesh.permute( - *all_mesh_dims, *mesh_dims, - ).reshape(-1, *mesh_sizes) - - for mesh_nd in pg_ranks_by_dim: - if cur_rank in mesh_nd: - sub_mesh = DeviceMesh( - device_mesh.device_type, - mesh_nd, - mesh_dim_names=mesh_dim_names, - ) - res_sub_mesh = sub_mesh - - res_sub_mesh._dim_group_infos = [ # type: ignore - device_mesh._dim_group_infos[mesh_dim] for mesh_dim in mesh_dims - ] - - # Assign the current DeviceMesh as the parent of the child DeviceMesh. - self.child_to_parent_mapping[res_sub_mesh] = device_mesh # type: ignore - return res_sub_mesh # type: ignore - - from torch.distributed.device_mesh import _mesh_resources - - def device_mesh__init__( - self, - device_type: str, - mesh, - *, - mesh_dim_names: Optional[tuple[str, ...]] = None, - ) -> None: - """Monkeypatch device mesh __init__ to nightly version.""" - self.device_type = device_type - if isinstance(mesh, torch.Tensor) and mesh.device.type != 'cpu': - raise ValueError(f'`mesh` must be a CPU tensor, got {mesh}') - self.mesh = ( - mesh.detach().cpu() - if isinstance(mesh, torch.Tensor) - else torch.tensor(mesh, dtype=torch.int) - ) - self.mesh_dim_names = mesh_dim_names - - # private field to pre-generate DeviceMesh's hash - self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) - self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self))) - self._parent_mesh = _mesh_resources.get_parent_mesh(self) - - # Skip process group initialization if xla device. - # TODO(yeounoh) implement DeviceMesh backend and register XLA backend. - if device_type != 'xla': - # always try to create default (world) pg, even if it is not initialized - # already. The world pg is used for device mesh identity (rank) on each - # process (we need to know if the current global rank is in the mesh or not). 
- self._get_or_create_default_group() - if not self._parent_mesh: - self._init_process_groups() - - def device_mesh__getitem__(self, mesh_dim_names: Union[str, tuple[str]]) -> 'DeviceMesh': - """Monkeypatch device_mesh __getitem__ to nightly version. - - Slice the current DeviceMesh based on the mesh_dim_name given to create a child - DeviceMesh. - - Args: - mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh - to create a child DeviceMesh for. - - Returns: - A :class:`DeviceMesh` object - - The following program runs on each process/rank in an SPMD manner. In this example, we have 2 - hosts with 4 GPUs each. - Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]). - Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]). - Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]). - Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]). - Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]). - Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]). - - Example:: - >>> # xdoctest: +SKIP("no rank") - >>> from torch.distributed.device_mesh import DeviceMesh - >>> - >>> # Initialize device mesh as (2, 4) to represent the topology - >>> # of cross-host(dim 0), and within-host (dim 1). - >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) - """ - if not self.mesh_dim_names: - raise RuntimeError('Cannot slice a DeviceMesh without mesh_dim_names.') - - mesh_dim_names = ( - (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names - ) - - error_msg = ( - f'Invalid mesh_dim_name {mesh_dim_names} specified. ' - f'Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}.' - ) - - # When the dimension slicing out is equal to the mesh dimensions of the current DeviceMesh, - # we simply return self if the given slicing is valid. - if mesh_dim_names == self.mesh_dim_names: - return self - # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names - # of the current DeviceMesh. - elif len(mesh_dim_names) < len(self.mesh_dim_names): - outermost_dim_name = mesh_dim_names[0] - if outermost_dim_name not in self.mesh_dim_names: - raise ValueError(error_msg) - outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name) - for i, j in zip( - mesh_dim_names, - self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)], - ): - if i != j: - raise ValueError(error_msg) - else: - raise ValueError(error_msg) - - submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names) - return submesh - - else: - from torch.utils._typing_utils import not_none - from torch.distributed.device_mesh import DeviceMesh, _mesh_resources - - def create_child_mesh( - self, parent_mesh: 'DeviceMesh', submesh_dim_names: tuple[str, ...], - ) -> 'DeviceMesh': - """Monkeypatch create_child_mesh to nightly version.""" - # submesh_dims are the mesh dimension of the submesh in the parent mesh. 
- submesh_dims = [ - not_none(parent_mesh.mesh_dim_names).index(mesh_dim_name) - for mesh_dim_name in submesh_dim_names - ] - submesh_dim_sizes = [ - parent_mesh.mesh.size(mesh_dim) for mesh_dim in submesh_dims - ] - - mesh_dims_remained = list(range(parent_mesh.mesh.ndim)) - for submesh_dim in submesh_dims: - mesh_dims_remained.remove(submesh_dim) - - # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *sub_mesh_dims] - # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with - # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank. - pg_ranks_by_dim = parent_mesh.mesh.permute( - *mesh_dims_remained, *submesh_dims, - ).reshape(-1, *submesh_dim_sizes) - - cur_rank = parent_mesh.get_rank() - for mesh_nd in pg_ranks_by_dim: - submesh = DeviceMesh( - parent_mesh.device_type, - mesh_nd, - mesh_dim_names=submesh_dim_names, - _init_backend=False, - ) - if cur_rank in mesh_nd: - res_submesh = submesh - - res_submesh._parent_mesh = parent_mesh # type: ignore - res_submesh._dim_group_infos = [ # type: ignore - parent_mesh._dim_group_infos[mesh_dim] for mesh_dim in submesh_dims # type: ignore - ] - self.child_to_parent_mapping[res_submesh] = parent_mesh # type: ignore - - return res_submesh # type: ignore - - def device_mesh__getitem__( - self, mesh_dim_names: Union[str, tuple[str, ...]], - ) -> 'DeviceMesh': - """Monkeypatch device_mesh __getitem__ to nightly version. - - Slice the current DeviceMesh based on the mesh_dim_name given to create a child - DeviceMesh. - - Args: - mesh_dim_name (Union[str, Tuple[str]]): the name or the tuple of names of the - mesh dimension of the parent DeviceMesh to create the child DeviceMesh for. - - Returns: - A :class:`DeviceMesh` object - - The following program runs on each process/rank in an SPMD manner. In this example, we have 2 - hosts with 4 GPUs each. - Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]). - Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]). - Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]). - Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]). - Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]). - Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]). - - Example:: - >>> # xdoctest: +SKIP("no rank") - >>> from torch.distributed.device_mesh import DeviceMesh - >>> - >>> # Initialize device mesh as (2, 4) to represent the topology - >>> # of cross-host(dim 0), and within-host (dim 1). - >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) - """ - if not self.mesh_dim_names: - raise RuntimeError('Cannot slice a DeviceMesh without mesh_dim_names!') - - mesh_dim_names = ( - (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names - ) - - error_msg = ( - f'Invalid mesh_dim_name {mesh_dim_names} specified. ' - f'Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}.' - ) - - if mesh_dim_names == self.mesh_dim_names: - return self - elif len(mesh_dim_names) > len(self.mesh_dim_names) or not all( - mesh_dim_name in self.mesh_dim_names for mesh_dim_name in mesh_dim_names - ): - raise KeyError(error_msg) - # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names - # of the current DeviceMesh. 
- else: - outermost_dim_name = mesh_dim_names[0] - outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name) - for i, j in zip( - mesh_dim_names, - self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)], - ): - if i != j: - raise KeyError(error_msg) - - submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names) - return submesh - - - # Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks - from torch.distributed.fsdp._flat_param import FlatParamHandle - original_unshard = FlatParamHandle.unshard - - @no_type_check - def unshard_with_sync(self): - """Run the unshard logic, but with a sync after a :meth:`_alloc_padded_unsharded_flat_param`. - - This prevents deadlocks when some ranks OOM after the alloc call and others do not. - This is a patched method from pytorch, meant to be called when automicrobatching - turns on hooks in its search process for the optimal non-OOMing microbatch size. - This includes all-gathering the flat parameter - and switching to using the unsharded flat parameter. If the handle does - not need unsharding, then this only switches to using the unsharded - flat parameter. For ``NO_SHARD``, this is a no-op. - If FSDP is in :meth:`summon_full_params` and the handle uses parameter - mixed precision, then the parameter is forced to full precision. - """ - if not self.needs_unshard(): - # Even when not needing an unshard, we should switch to using - # the unsharded flat parameter - unsharded_flat_param = ( - self._get_padded_unsharded_flat_param() - if self.uses_sharded_strategy - else self.flat_param - ) - self._use_unsharded_flat_param(unsharded_flat_param) - return - unsharded_flat_param = self._alloc_padded_unsharded_flat_param() - - # Check if any other rank hit an OOM - found_cuda_oom_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True) - - dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX') - found_cuda_oom = found_cuda_oom_tensor.item() - # Signal current rank is still in batch - all_ranks_finished_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True) - - dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN') - - if found_cuda_oom == 1: - raise RuntimeError('CUDA out of memory encountered on a different rank') - padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param) - self._use_unsharded_flat_param(padded_unsharded_flat_param) - - -if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse( - torch.__version__, -) < version.parse('2.4.2'): - # PyTorch issue: https://github.com/pytorch/pytorch/issues/133923 - from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE - from typing import Mapping, Collection - PATH_ITEM = Union[str, int] - OBJ_PATH = tuple[PATH_ITEM, ...] - STATE_DICT_ITEM = object - - def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: - return isinstance(value, torch.Tensor) - - # Override the traverse_state_dict to address issue https://github.com/pytorch/pytorch/issues/133923 - # Torch2.4 changed this function for save_planner and load_planner to flatten the state dict. - # It broke backward compatibility. New load_planner can't load checkpointing saved by old save_planner. - # 2.3. 
vs 2.4 diff: https://github.com/pytorch/pytorch/commit/6f1e3a6bf73327a351dc8a8c08635bd727b3134f - def traverse_state_dict( - state_dict: STATE_DICT_TYPE, - visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], - keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, - ) -> None: - """Invoke ``visitor`` for each value recursively in ``state_dict``. - - Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates - to false for all elements. - By default, all collections with at least one ``torch.Tensor`` element are traversed. - Visitor takes a path argument that is a tuple of the keys used to reach it. - """ - # a value is terminal if it has no other containers values inside it - def _is_terminal(value: STATE_DICT_ITEM) -> bool: - values: Collection[STATE_DICT_ITEM] - if isinstance(value, Mapping): - values = value.values() - elif isinstance(value, list): - values = value - else: - return True - - for entry in values: - if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): - return False - if keep_traversing is not None and keep_traversing(entry): # type: ignore - return False - return True - - def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: - if _is_terminal(value): - visitor(path, value) - elif isinstance(value, Mapping): - for k, v in value.items(): - _traverse_obj(path + (str(k),), v) - elif isinstance(value, list): - for i, v in enumerate(value): - _traverse_obj(path + (i,), v) - - for key, value in state_dict.items(): - _traverse_obj((str(key),), value) - - # Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks - from torch.distributed.fsdp._flat_param import FlatParamHandle - original_unshard = FlatParamHandle.unshard - - @no_type_check - def unshard_with_sync(self): - """Run the unshard logic, but with a sync after a :meth:`_alloc_padded_unsharded_flat_param`. - - This prevents deadlocks when some ranks OOM after the alloc call and others do not. - This is a patched method from pytorch, meant to be called when automicrobatching - turns on hooks in its search process for the optimal non-OOMing microbatch size. - This includes all-gathering the flat parameter - and switching to using the unsharded flat parameter. If the handle does - not need unsharding, then this only switches to using the unsharded - flat parameter. For ``NO_SHARD``, this is a no-op. - If FSDP is in :meth:`summon_full_params` and the handle uses parameter - mixed precision, then the parameter is forced to full precision. 
- """ - if not self.needs_unshard(): - # Even when not needing an unshard, we should switch to using - # the unsharded flat parameter - unsharded_flat_param = ( - self._get_padded_unsharded_flat_param() - if self.uses_sharded_strategy - else self.flat_param - ) - self._use_unsharded_flat_param(unsharded_flat_param) - return - unsharded_flat_param = self._alloc_padded_unsharded_flat_param() - - # Check if any other rank hit an OOM - found_cuda_oom_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True) - - dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX') - found_cuda_oom = found_cuda_oom_tensor.item() - # Signal current rank is still in batch - all_ranks_finished_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True) - - dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN') - - if found_cuda_oom == 1: - raise RuntimeError('CUDA out of memory encountered on a different rank') - padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param) - self._use_unsharded_flat_param(padded_unsharded_flat_param) -if version.parse(torch.__version__) >= version.parse('2.5.0') and version.parse( +if version.parse(torch.__version__) >= version.parse('2.6.0') and version.parse( torch.__version__, ) < version.parse('2.7.1'): diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 940a4642cd..c0154d3fed 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -38,7 +38,6 @@ import torch.distributed import torch.nn as nn import torch.utils.data -from packaging import version from torch._dynamo import OptimizedModule from torch.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state from torch.distributed.fsdp import FullyShardedDataParallel @@ -75,6 +74,7 @@ from composer.distributed import ( DDPSyncStrategy, ddp_sync_context, + parallelize_composer_model, prepare_ddp_module, prepare_fsdp_module, prepare_tp_module, @@ -1859,7 +1859,6 @@ def _wrap_model_for_distributed( self.state.seed, ) case 2: - from composer.distributed.prepare_distributed import parallelize_composer_model parallelize_composer_model( model, optimizers, @@ -2174,18 +2173,6 @@ def fit( # different units than ``max_duration`` self.state.max_duration = duration + self.state.timestamp.get(duration.unit) - # Raise error if callig fit with SGD - if ( - type(self.state.optimizers[0]) == torch.optim.SGD and - version.parse(torch.__version__) >= version.parse('2.4.0') and - version.parse(torch.__version__) < version.parse('2.5.0') - ): - raise ValueError( - 'PyTorch 2.4 breaks (distributed) checkpointing with SGD. ' - 'Please use a different optimizer, e.g. composer.optim.DecoupledSGDW, ' - 'instead. See https://github.com/pytorch/pytorch/issues/133415 for further information.', - ) - if self.state.max_duration is None: _raise_missing_argument_exception('max_duration') assert self.state.max_duration is not None diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index cfaec61cb4..54f0101723 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -596,34 +596,7 @@ def dist_cp_load( planner=load_planner, ) return - - checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata - if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata: - # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility. 
- # Torch issue: https://github.com/pytorch/pytorch/issues/133923. - # We override the traverse_state_dict so that the load planner could - # use the old way of flattening the state dict - log.debug('Trying to load checkpointing saved before torch 2.4') - - import torch.distributed.checkpoint._nested_dict as nested_dict - import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util - from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0 - - from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse - - nested_dict.traverse_state_dict = backward_compatible_traverse - sharded_tensor_util.traverse_state_dict = backward_compatible_traverse - - dist_cp.load( - state_dict=state_dict, - storage_reader=storage_reader, - planner=load_planner, - ) - # Revert the override - nested_dict.traverse_state_dict = traverse_2_4_0 - sharded_tensor_util.traverse_state_dict = traverse_2_4_0 - else: - raise e + raise e def load_sharded_checkpoint( diff --git a/pyproject.toml b/pyproject.toml index 1404d19839..88d05f3f45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -187,6 +187,8 @@ filterwarnings = [ '''ignore:.*`_get_pg_default_device` will be deprecated.*:UserWarning''', # Ignore ROCMMonitor warnings '''ignore:.*ROCMMonitor.__del__:pytest.PytestUnraisableExceptionWarning''', + # Ignore FSDP2Config deprecation warning + '''ignore:.*FSDP2 Config/APIs are experimental*:UserWarning''', ] # Coverage diff --git a/tests/algorithms/test_required_on_load.py b/tests/algorithms/test_required_on_load.py index 98e3462163..566167df31 100644 --- a/tests/algorithms/test_required_on_load.py +++ b/tests/algorithms/test_required_on_load.py @@ -8,7 +8,6 @@ import pytest import torch -from packaging import version from composer import Trainer, algorithms from composer.callbacks import CheckpointSaver @@ -17,7 +16,6 @@ from composer.core import TimeUnit # noqa: F401 # type: ignore imports used in `eval(representation)` from composer.core import Algorithm from composer.models import ComposerClassifier, ComposerModel -from composer.utils import dist from tests.common import ConvModel, SimpleConvModel, composer_resnet @@ -176,36 +174,16 @@ def test_autoload( context = pytest.warns(UserWarning, match='Automatically adding required_on_load algorithm*') # Excluding some algorithms leads to errors when loading elif exclude: - if version.parse(torch.__version__) >= version.parse('2.4.0'): - if algo_name in [ - 'BlurPool', - 'Factorize', - 'GatedLinearUnits', - 'GhostBatchNorm', - 'SqueezeExcite', - ]: - context = pytest.raises(KeyError) # Optimizer loading is strict - elif algo_name == 'Alibi': - context = pytest.raises(RuntimeError) # Alibi has shape issues - elif version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized(): - if algo_name in [ - 'Alibi', - 'BlurPool', - 'Factorize', - 'GatedLinearUnits', - 'GhostBatchNorm', - 'SqueezeExcite', - ]: - context = pytest.raises(KeyError) # Optimizer loading is strict - else: - if algo_name in ['Factorize', 'SqueezeExcite']: - context = pytest.raises( - ValueError, - match= - "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", - ) - elif algo_name == 'Alibi': - context = pytest.raises(RuntimeError) + if algo_name in [ + 'BlurPool', + 'Factorize', + 'GatedLinearUnits', + 'GhostBatchNorm', + 'SqueezeExcite', + ]: + context = pytest.raises(KeyError) # Optimizer loading is strict + elif algo_name == 'Alibi': + context = 
pytest.raises(RuntimeError)  # Alibi has shape issues
         with context:
             trainer2 = Trainer(
diff --git a/tests/callbacks/test_memory_snapshot.py b/tests/callbacks/test_memory_snapshot.py
index 4f12b586fe..66d8827159 100644
--- a/tests/callbacks/test_memory_snapshot.py
+++ b/tests/callbacks/test_memory_snapshot.py
@@ -4,8 +4,6 @@
 import pathlib
 
 import pytest
-import torch
-from packaging import version
 from torch.utils.data import DataLoader
 
 from composer import State, Trainer
@@ -15,10 +13,6 @@
 from tests.common import RandomClassificationDataset, SimpleModel
 
 
-@pytest.mark.skipif(
-    version.parse(torch.__version__) < version.parse('2.1.0'),
-    reason='OOM Observer requires PyTorch 2.1 or higher',
-)
 def test_memory_snapshot_warnings_on_cpu_models():
     with pytest.warns(UserWarning):
         Trainer(
@@ -42,10 +36,6 @@ def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Pa
 
 @pytest.mark.gpu
 @pytest.mark.parametrize('interval', ['1ba'])
-@pytest.mark.skipif(
-    version.parse(torch.__version__) < version.parse('2.1.0'),
-    reason='OOM Observer requires PyTorch 2.1 or higher',
-)
 def test_memory_snapshot(interval: str):
     # Construct the callbacks
     skip_batches = 0
diff --git a/tests/callbacks/test_oom_observer.py b/tests/callbacks/test_oom_observer.py
index c3b2dcc0dd..d7248db903 100644
--- a/tests/callbacks/test_oom_observer.py
+++ b/tests/callbacks/test_oom_observer.py
@@ -5,7 +5,6 @@
 
 import pytest
 import torch
-from packaging import version
 from torch.utils.data import DataLoader
 
 from composer import State, Trainer
@@ -15,10 +14,6 @@
 from tests.common import RandomClassificationDataset, SimpleModel
 
 
-@pytest.mark.skipif(
-    version.parse(torch.__version__) < version.parse('2.1.0'),
-    reason='OOM Observer requires PyTorch 2.1 or higher',
-)
 def test_oom_observer_warnings_on_cpu_models():
     ob = OOMObserver()
     with pytest.warns(UserWarning):
@@ -43,10 +38,6 @@ def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Pa
 
 
 @pytest.mark.gpu
-@pytest.mark.skipif(
-    version.parse(torch.__version__) < version.parse('2.1.0'),
-    reason='OOM Observer requires PyTorch 2.1 or higher',
-)
 def test_oom_observer():
     # Construct the callbacks
     oom_observer = OOMObserver()
@@ -71,10 +62,6 @@ def test_oom_observer():
 
 
 @pytest.mark.gpu
-@pytest.mark.skipif(
-    version.parse(torch.__version__) < version.parse('2.1.0'),
-    reason='OOM Observer requires PyTorch 2.1 or higher',
-)
 def test_oom_observer_with_memory_snapshot():
     # Construct the callbacks
     oom_observer = OOMObserver()
diff --git a/tests/callbacks/test_optimizer_monitor.py b/tests/callbacks/test_optimizer_monitor.py
index 523408c09c..a9ed5e1898 100644
--- a/tests/callbacks/test_optimizer_monitor.py
+++ b/tests/callbacks/test_optimizer_monitor.py
@@ -2,8 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
-import torch
-from packaging import version
 from torch.utils.data import DataLoader
 
 from composer.callbacks import OptimizerMonitor
@@ -53,10 +51,6 @@ def test_optimizer_monitor(log_optimizer_metrics: bool, batch_log_interval: int)
 
 @device('gpu')
 @world_size(1, 2)
-@pytest.mark.skipif(
-    version.parse(torch.__version__) < version.parse('1.13.0'),
-    reason='requires PyTorch 1.13 or higher',
-)
 @pytest.mark.parametrize('use_orig_params', [True, False])
 def test_fsdp_optimizer_monitor(device, world_size, use_orig_params):
     # Construct the callback
diff --git a/tests/checkpoint/helpers.py b/tests/checkpoint/helpers.py
index 52838c9aa5..b9e6a331fe 100644
--- a/tests/checkpoint/helpers.py
+++ b/tests/checkpoint/helpers.py
@@ -5,7 +5,6 @@
 from unittest.mock import MagicMock
 
 import torch
-from packaging import version
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.api import CPUOffload
 from torch.optim import Adam
@@ -62,10 +61,7 @@ def init_state(
     if include_algorithms:
         kwargs['algorithms'] = [SWA()]
     if use_grad_scaler:
-        if version.parse(torch.__version__) >= version.parse('2.3.0'):
-            from torch.amp.grad_scaler import GradScaler
-        else:
-            from torch.cuda.amp.grad_scaler import GradScaler
+        from torch.amp.grad_scaler import GradScaler
         kwargs['scaler'] = GradScaler()
 
     state = State(
diff --git a/tests/checkpoint/test_save.py b/tests/checkpoint/test_save.py
index 31c95afbe4..124913222c 100644
--- a/tests/checkpoint/test_save.py
+++ b/tests/checkpoint/test_save.py
@@ -11,7 +11,6 @@
 import pytest
 import torch
 import torch.distributed.checkpoint as DCP
-from packaging import version
 
 from composer.checkpoint.save import (
     save_checkpoint_to_disk,
@@ -116,10 +115,7 @@ def test_save_optim_to_disk(world_size: int, tmp_path: str, sharded_optimizer: b
 
     if sharded_checkpoint:
         expected_file_path = os.path.join(destination_dir)
-        if version.parse(torch.__version__) < version.parse('2.2.0'):
-            DCP.load_state_dict(state_dict=cur_state_dict, storage_reader=DCP.FileSystemReader(expected_file_path))
-        else:
-            DCP.load(state_dict=cur_state_dict, storage_reader=DCP.FileSystemReader(expected_file_path))
+        DCP.load(state_dict=cur_state_dict, storage_reader=DCP.FileSystemReader(expected_file_path))
     else:
         if dist.get_global_rank() == 0:
             expected_file_path = destination_dir
@@ -154,10 +150,7 @@ def test_save_model_to_disk(world_size: int, tmp_path: str, sharded_model: bool,
 
     if sharded_checkpoint:
         expected_file_path = destination_dir
-        if version.parse(torch.__version__) < version.parse('2.2.0'):
-            DCP.load_state_dict(state_dict=cur_state_dict, storage_reader=DCP.FileSystemReader(expected_file_path))
-        else:
-            DCP.load(state_dict=cur_state_dict, storage_reader=DCP.FileSystemReader(expected_file_path))
+        DCP.load(state_dict=cur_state_dict, storage_reader=DCP.FileSystemReader(expected_file_path))
     else:
         if dist.get_global_rank() == 0:
             expected_file_path = destination_dir
@@ -195,13 +188,7 @@ def test_save_full_state_dict_to_disk(world_size: int, tmp_path: str, sharded_mo
     'tensor_type',
     [
         'sharded_tensor',
-        pytest.param(
-            'dtensor',
-            marks=pytest.mark.skipif(
-                version.parse(torch.__version__) < version.parse('2.2.0'),
-                reason='Requires torch>=2.2.0 for dtensor',
-            ),
-        ),
+        'dtensor',
     ],
 )
 def test_save_sharded_state_dict_to_disk(world_size: int, tmp_path: str, tensor_type: str):
@@ -217,8 +204,5 @@ def test_save_sharded_state_dict_to_disk(world_size: int, tmp_path: str, tensor_
     assert path_saved == f'{destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}'
     assert path_saved is not None
     load_path = str(Path(path_saved).parent)
-    if version.parse(torch.__version__) < version.parse('2.2.0'):
-        DCP.load_state_dict(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path))
-    else:
-        DCP.load(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path))
+    DCP.load(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path))
     deep_compare(state_dict, loaded_in_state_dict)
diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py
index 4afb90846f..1b649bd20a 100644
--- a/tests/checkpoint/test_state_dict.py
+++ b/tests/checkpoint/test_state_dict.py
@@ -6,7 +6,6 @@
 
 import pytest
 import torch
-from packaging import version
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from composer.checkpoint import (
@@ -71,8 +70,6 @@ def test_get_model_state_dict_ignore(use_composer_model: bool):
 @pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
 @pytest.mark.parametrize('use_composer_model', [True, False])
 def test_get_model_state_dict_full_for_sharded_model(world_size, tensor_type, use_composer_model: bool):
-    if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'):
-        pytest.skip('DTensor is only supported in PyTorch >= 2.2.0')
     if use_composer_model:
         model = SimpleComposerMLP(num_features=16, device='cuda')
     else:
@@ -110,9 +107,6 @@ def test_get_model_state_dict_full_for_sharded_model(world_size, tensor_type, us
 @pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
 @pytest.mark.parametrize('use_composer_model', [True, False])
 def test_get_model_state_dict_sharded(world_size, tensor_type, use_composer_model: bool):
-    if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'):
-        pytest.skip('DTensor is only supported in PyTorch >= 2.2.0')
-
     if use_composer_model:
         model = SimpleComposerMLP(num_features=16, device='cuda')
     else:
@@ -181,8 +175,6 @@ def test_get_model_state_dict_precision_sharded_model(
     precision: str,
     use_composer_model: bool,
 ):
-    if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'):
-        pytest.skip('DTensor is only supported in PyTorch >= 2.2.0')
     if use_composer_model:
         model = SimpleComposerMLP(num_features=8, device='cuda')
     else:
@@ -308,9 +300,6 @@ def test_get_optim_state_dict_precision_unsharded_model(precision: str, use_comp
 @pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
 @pytest.mark.parametrize('use_composer_model', [True, False])
 def test_get_optim_dict_full_for_sharded_model(world_size, tensor_type, use_composer_model: bool):
-    if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'):
-        pytest.skip('DTensor is only supported in PyTorch >= 2.2.0')
-
     model, optimizer = init_model_and_optimizer(
         use_composer_model=use_composer_model,
         take_step=True,
@@ -339,9 +328,6 @@
 @pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
 @pytest.mark.parametrize('use_composer_model', [True, False])
 def test_get_optim_dict_sharded_for_sharded_model(world_size, tensor_type, use_composer_model: bool):
-    if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'):
-        pytest.skip('DTensor is only supported in PyTorch >= 2.2.0')
-
     model, optimizer = init_model_and_optimizer(
         use_composer_model=use_composer_model,
         take_step=True,
@@ -409,8 +395,6 @@ def test_get_metadata_unsharded_model(model_type: str):
 @pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
 @pytest.mark.parametrize('model_type', ['composer', 'hf', 'nn.module'])
 def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_size: int):
-    if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'):
-        pytest.skip('DTensor is only supported in PyTorch >= 2.2.0')
     if model_type == 'composer':
         model = SimpleComposerMLP(num_features=8, device='cuda')
         expected_model_name = 'SimpleComposerMLP'
diff --git a/tests/models/test_hf_model.py b/tests/models/test_hf_model.py
index e4aa8620f1..47c91b8bcb 100644
--- a/tests/models/test_hf_model.py
+++ 
b/tests/models/test_hf_model.py @@ -12,7 +12,6 @@ import pytest import torch -from packaging import version from torch.utils.data import DataLoader from torchmetrics import Metric from torchmetrics.classification import MulticlassAccuracy @@ -502,16 +501,6 @@ def get_lm_trainer( should_save_peft_only=should_save_peft_only, ) - # On torch 2.0, fsdp wrapped modules can not have both frozen and unfrozen params. - # On 2.1+, if you have use_orig_params=True, they can. So we need a special case for the tests here. - if version.parse(torch.__version__) < version.parse('2.1.0') and peft_config is not None: - for name, module in model.named_modules(): - if 'lora' in name.lower() and 'default' in name.lower(): - has_parameters = any(True for _ in module.parameters()) - has_buffers = any(True for _ in module.buffers()) - if has_parameters or has_buffers: - module._fsdp_wrap = True # type: ignore - vocab_size = hf_model.config.vocab_size sequence_length = 4 size = 4 @@ -1324,7 +1313,6 @@ def test_peft_init(peft_type: str, task_type: str, tiny_gpt2_model, gpt2_peft_co assert hf_model.model.config == original_model.config -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+') def test_peft_init_errors(tiny_gpt2_model, gpt2_peft_config): pytest.importorskip('peft') peft_config = copy.deepcopy(gpt2_peft_config) @@ -1334,7 +1322,6 @@ def test_peft_init_errors(tiny_gpt2_model, gpt2_peft_config): _ = HuggingFaceModel(tiny_gpt2_model, peft_config=peft_config) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+') def test_peft_init_not_installed(tiny_gpt2_model, gpt2_peft_config): pytest.importorskip('peft') @@ -1344,7 +1331,6 @@ def test_peft_init_not_installed(tiny_gpt2_model, gpt2_peft_config): _ = HuggingFaceModel(tiny_gpt2_model, peft_config=gpt2_peft_config) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+') @pytest.mark.parametrize('should_save_peft_only', [True, False]) def test_peft_trains_and_loads(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config, tmp_path, should_save_peft_only): pytest.importorskip('peft') @@ -1375,7 +1361,6 @@ def test_peft_trains_and_loads(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_c torch.testing.assert_close(p1, p2) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+') @pytest.mark.parametrize( 'model,tokenizer,peft_config', [ @@ -1398,7 +1383,6 @@ def test_peft_generate(model, tokenizer, peft_config): hf_model.generate(**input_dict, max_new_tokens=5, pad_token_id=tokenizer.pad_token_id) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+') def test_peft_metadata(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config): pytest.importorskip('peft') @@ -1411,7 +1395,6 @@ def test_peft_metadata(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config): assert loaded_peft_config == gpt2_peft_config -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+') @pytest.mark.parametrize('should_save_peft_only', [True, False]) def test_peft_write_hf_from_composer( tiny_gpt2_model, @@ -1536,7 +1519,6 @@ def test_peft_fsdp_trains( assert not all('lora' in k for k in loaded_ckpt_1['state']['model'].keys()) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+') def test_filtered_state_dict(tiny_gpt2_model, 
tiny_gpt2_tokenizer, gpt2_peft_config, tmp_path): pytest.importorskip('peft') diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 5014c9faf4..734e84a174 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -8,7 +8,6 @@ import pytest import torch -from packaging import version from torch.profiler.profiler import ProfilerAction as TorchProfilerAction from composer.core import Engine, Event, State, Timestamp @@ -145,9 +144,6 @@ def test_profiler_error_message( @pytest.mark.gpu def test_memory_timeline(tmp_path: pathlib.Path) -> None: - if version.parse(torch.__version__) <= version.parse('2.1.0.dev'): - # memory timeline is supported after PyTorch 2.1.0. - return import torch.profiler._memory_profiler as _memory_profiler model = torch.nn.Sequential( diff --git a/tests/test_full_nlp.py b/tests/test_full_nlp.py index 0670d2988b..cafdfb6b82 100644 --- a/tests/test_full_nlp.py +++ b/tests/test_full_nlp.py @@ -6,7 +6,6 @@ import pytest import torch -from packaging import version from torch.utils.data import DataLoader from torchmetrics.classification import MulticlassAccuracy from transformers import BertConfig, BertForMaskedLM, BertForSequenceClassification @@ -255,9 +254,6 @@ def test_full_nlp_pipeline( pytest.importorskip('libcloud') pytest.importorskip('transformers') - if onnx_opset_version == None and version.parse(torch.__version__) < version.parse('1.13'): - pytest.skip("Don't test prior PyTorch version's default Opset version.") - algorithms = [algorithm() for algorithm in algorithms] device = get_device(device) config = None diff --git a/tests/trainer/fsdp2_context.py b/tests/trainer/fsdp2_context.py deleted file mode 100644 index 6fa8f6e8e3..0000000000 --- a/tests/trainer/fsdp2_context.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -from typing import Callable, Optional - -import pytest -import torch -from packaging import version - -SKIP_TEST = version.parse(torch.__version__) < version.parse('2.6.0') -if not SKIP_TEST: - # TODO (FSDP2) move this to top once we deprecate torch 2.5 - from composer.distributed import activation_checkpointing, fsdp2, fsdp2_utils, prepare_distributed - apply_ac = activation_checkpointing.apply_ac - parallelize_model = prepare_distributed.parallelize_model - legalize_param_sharing_between_modules = fsdp2_utils.legalize_param_sharing_between_modules - get_standalone_and_tied_modules = fsdp2_utils.get_standalone_and_tied_modules - _recursive_apply_fully_shard = fsdp2._recursive_apply_fully_shard - _generate_default_policy = fsdp2_utils.generate_default_policy - check_param_tying = fsdp2_utils.check_param_tying -else: - apply_ac = lambda *args, **kwargs: None - parallelize_model = lambda *args, **kwargs: None - legalize_param_sharing_between_modules = lambda *args, **kwargs: None - get_standalone_and_tied_modules = lambda *args, **kwargs: ([], set()) - _recursive_apply_fully_shard = lambda *args, **kwargs: None - _generate_default_policy = lambda *args, **kwargs: lambda *args, **kwargs: None - check_param_tying = lambda *args, **kwargs: None - - -def fsdp2_context(func: Callable) -> Optional[Callable]: - """Decorator to run tests with models initialized on the meta device for torch version 2.6+.""" - func = pytest.mark.skipif(SKIP_TEST, reason='Skipping test for torch version < 2.6.0')(func) - func = pytest.mark.filterwarnings('ignore:FSDP2 Config/APIs are experimental*:UserWarning')(func) - return func diff --git 
a/tests/trainer/test_activation_checkpointing.py b/tests/trainer/test_activation_checkpointing.py index d4bae3a43c..f3b2dcf2f7 100644 --- a/tests/trainer/test_activation_checkpointing.py +++ b/tests/trainer/test_activation_checkpointing.py @@ -11,21 +11,17 @@ OffloadWrapper, ) +from composer.distributed.activation_checkpointing import apply_ac from tests.common import ( ComposerCounterModel, CountModule, world_size, ) -from tests.trainer.fsdp2_context import ( - apply_ac, - fsdp2_context, -) from tests.trainer.test_fsdp2 import create_trainer_with_model @world_size(2) @pytest.mark.gpu -@fsdp2_context @pytest.mark.parametrize( 'activation_checkpointing,expected_forward_count,activation_cpu_offload', [ @@ -106,7 +102,6 @@ def test_fsdp2_activation_checkpointing_attribute( @world_size(2) @pytest.mark.gpu -@fsdp2_context @pytest.mark.parametrize( 'activation_checkpointing,expected_forward_count,activation_cpu_offload', [ @@ -177,7 +172,6 @@ def activation_checkpointing_fn(module: torch.nn.Module) -> bool: @world_size(2) @pytest.mark.gpu -@fsdp2_context @pytest.mark.parametrize('type_of_checkpointing', ['attribute', 'function']) def test_activation_checkpointing_cuda_memory_usage(world_size: int, type_of_checkpointing: str): """Test FSDP2 activation checkpointing CUDA memory usage.""" @@ -247,7 +241,6 @@ def activation_checkpointing_fn(module: torch.nn.Module) -> bool: @world_size(2) @pytest.mark.gpu -@fsdp2_context @pytest.mark.parametrize( 'activation_checkpointing,activation_cpu_offload', [ diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 60e6ee8297..212af9500d 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -19,7 +19,6 @@ import pytest import torch import torch.distributed -from packaging import version from pytest import MonkeyPatch from torch.utils.data import DataLoader, Dataset, DistributedSampler @@ -413,47 +412,6 @@ def test_checkpoint_saver_properly_constructed( else: assert attr == value - # See https://github.com/pytorch/pytorch/issues/133415 - @pytest.mark.xfail - @pytest.mark.skipif( - ( - version.parse(torch.__version__) < version.parse('2.4.0') or - version.parse(torch.__version__) >= version.parse('2.5.0') - ), - reason='Test only applies to PyTorch 2.4.x', - ) - def test_sgd_checkpoint( - self, - tiny_bert_tokenizer, - tmp_path: pathlib.Path, - ): - transformers = pytest.importorskip('transformers') - model = SimpleTransformerMaskedLM(vocab_size=tiny_bert_tokenizer.vocab_size) - pretraining_train_dataset = RandomTextLMDataset( - size=100, - vocab_size=tiny_bert_tokenizer.vocab_size, - sequence_length=1, - use_keys=True, - ) - - collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer, mlm_probability=0.15) - dataloader = DataLoader( - pretraining_train_dataset, - batch_size=1, - sampler=dist.get_sampler(pretraining_train_dataset), - collate_fn=collator, - ) - - trainer = Trainer( - model=model, - optimizers=torch.optim.SGD(model.parameters(), lr=0.1), - train_dataloader=dataloader, - max_duration='5ba', - save_interval='1ba', - save_folder=str(tmp_path / 'checkpoints'), - ) - trainer.fit() - @pytest.mark.parametrize('save_interval', ['1tok', '64tok', '65tok']) @pytest.mark.parametrize('batch_size', [1, 4]) @pytest.mark.parametrize('sequence_length', [1, 16]) @@ -1055,11 +1013,6 @@ def test_strict_errors(self, missing_key: bool, unexpected_key: bool): last_checkpoint = os.path.join('first', 'ep2.pt') if missing_key or unexpected_key: message = r'Error\(s\) in loading 
state_dict' - if version.parse(torch.__version__) < version.parse('2.2.3') or ( - version.parse(torch.__version__) < version.parse('2.4.0') and not dist.is_initialized() - ): - # Composer implements strict for older torch versions - message = 'Failed to load checkpoint due to' error_context = pytest.raises(RuntimeError, match=message) else: error_context = contextlib.nullcontext() @@ -1415,10 +1368,6 @@ def test_autoload_algorithm_old_checkpoint(self): NoOpModel.__init__ = lambda self, x: None # type: ignore NoOpModel.__repr__ = lambda self: 'NoOpModel(3)' error_context = pytest.raises(KeyError, match='module.0.weight') - if version.parse(torch.__version__) < version.parse('2.2.3') or ( - version.parse(torch.__version__) < version.parse('2.4.0') and not dist.is_initialized() - ): - error_context = pytest.raises(ValueError, match='loaded state dict contains a parameter group.*') with pytest.warns(UserWarning, match='required_on_load algorithm.*'), error_context: trainer_3 = self.get_trainer(load_path=os.path.join('first', 'ep1.pt')) trainer_3.fit(duration='1ba') diff --git a/tests/trainer/test_ddp_sync_strategy.py b/tests/trainer/test_ddp_sync_strategy.py index dfe711b92f..3de2c0ed9b 100644 --- a/tests/trainer/test_ddp_sync_strategy.py +++ b/tests/trainer/test_ddp_sync_strategy.py @@ -6,7 +6,6 @@ import pytest import torch import torch.nn as nn -from packaging import version from torch import Tensor from torch.utils.data import DataLoader @@ -66,8 +65,6 @@ def test_ddp_sync_strategy( rank_zero_seed: int, request: pytest.FixtureRequest, ): - if version.parse(torch.__version__) < version.parse('2.4.0'): - pytest.skip('Before PyTorch 2.4, single_auto_sync did not properly run on last microbatch') original_model = MinimalConditionalModel() optimizer = torch.optim.SGD(original_model.parameters(), 0.1) device = None diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 0bbe9aee55..5bb2f0d284 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -6,8 +6,7 @@ import pytest import torch -from packaging import version -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper, OffloadWrapper from torch.utils.data import DataLoader, Dataset from composer.models import ComposerClassifier, ComposerModel @@ -331,10 +330,6 @@ def test_fsdp_automicrobatching_sync_hooks(world_size: int): @pytest.mark.gpu @world_size(2) -@pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2'), - reason='FSDP use_orig_params requires torch 2.0 or higher', -) def test_fsdp_subset_of_params_in_opt(world_size: int): model = SimpleModel() dataset = RandomClassificationDataset(size=10) @@ -436,23 +431,21 @@ def test_fsdp_act_ckpt_offload( ) assert trainer.state.fsdp_enabled - if version.parse(torch.__version__) > version.parse('2.1.0.dev'): - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import OffloadWrapper - - assert isinstance(trainer.state.model.fc1, torch.nn.Module) - - if activation_checkpointing and activation_cpu_offload: - assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, OffloadWrapper) - assert isinstance( - trainer.state.model.fc1._fsdp_wrapped_module._checkpoint_wrapped_module, - CheckpointWrapper, - ) - elif activation_checkpointing: - assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, CheckpointWrapper) - elif activation_cpu_offload: - assert 
isinstance(trainer.state.model.fc1._fsdp_wrapped_module, OffloadWrapper) - else: - assert not isinstance(trainer.state.model.fc1._fsdp_wrapped_module, CheckpointWrapper) + + assert isinstance(trainer.state.model.fc1, torch.nn.Module) + + if activation_checkpointing and activation_cpu_offload: + assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, OffloadWrapper) + assert isinstance( + trainer.state.model.fc1._fsdp_wrapped_module._checkpoint_wrapped_module, + CheckpointWrapper, + ) + elif activation_checkpointing: + assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, CheckpointWrapper) + elif activation_cpu_offload: + assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, OffloadWrapper) + else: + assert not isinstance(trainer.state.model.fc1._fsdp_wrapped_module, CheckpointWrapper) @pytest.mark.gpu diff --git a/tests/trainer/test_fsdp2.py b/tests/trainer/test_fsdp2.py index e27921fa3f..1f847437dc 100644 --- a/tests/trainer/test_fsdp2.py +++ b/tests/trainer/test_fsdp2.py @@ -20,7 +20,6 @@ SimpleWeightTiedModel, world_size, ) -from tests.trainer.fsdp2_context import fsdp2_context _INIT_DEVICES = ['cuda', 'meta'] @@ -65,7 +64,6 @@ def create_trainer_with_model( @pytest.mark.parametrize('device', _INIT_DEVICES) @world_size(2) @pytest.mark.gpu -@fsdp2_context def test_fsdp2_initialization_with_tied_params( model_class: type, device: str, @@ -112,7 +110,6 @@ def test_fsdp2_initialization_with_tied_params( @pytest.mark.parametrize('device', _INIT_DEVICES) @world_size(2) @pytest.mark.gpu -@fsdp2_context def test_fsdp2_checkpointing( model_class: type, device: str, @@ -161,7 +158,6 @@ def test_fsdp2_checkpointing( @world_size(2) @pytest.mark.gpu -@fsdp2_context def test_fsdp2_load_from_fsdp1( world_size: int, tmp_path: pathlib.Path, @@ -212,7 +208,6 @@ def test_fsdp2_load_from_fsdp1( @world_size(2) @pytest.mark.gpu -@fsdp2_context @pytest.mark.parametrize('case', ['all_params_one_group', 'subset_one_group', 'multiple_groups']) @pytest.mark.parametrize('device', _INIT_DEVICES) def test_fsdp2_optimizer_handling( @@ -289,7 +284,6 @@ def validate_optimizer_state(current_optimizer: torch.optim.Optimizer, stage: st @world_size(2) @pytest.mark.gpu -@fsdp2_context def test_fsdp2_optimizer_raises_error_when_optimizer_modules_dont_match( world_size: int, ): diff --git a/tests/trainer/test_fsdp2_config.py b/tests/trainer/test_fsdp2_config.py index 7e49233d55..283649a6bd 100644 --- a/tests/trainer/test_fsdp2_config.py +++ b/tests/trainer/test_fsdp2_config.py @@ -4,12 +4,8 @@ import pytest from composer.utils.parallelism import FSDP2Config -from tests.trainer.fsdp2_context import ( - fsdp2_context, -) -@fsdp2_context def test_fsdp2_config(): """Test that FSDP2Config read-only properties work as expected.""" # Create a config instance @@ -47,7 +43,6 @@ def test_fsdp2_config(): assert config.reshard_after_forward is False -@fsdp2_context def test_fsdp2config_from_fsdp1_valid_attributes(): """Test FSDP2Config.from_compatible_attrs with valid attributes.""" valid_attrs = { @@ -64,7 +59,6 @@ def test_fsdp2config_from_fsdp1_valid_attributes(): assert fsdp2_config.device_mesh is None -@fsdp2_context def test_fsdp2config_from_empty_attributes(): """Test FSDP2Config.from_compatible_attrs with empty attributes.""" empty_attrs = {} @@ -76,7 +70,6 @@ def test_fsdp2config_from_empty_attributes(): assert fsdp2_config.device_mesh is None # default value -@fsdp2_context def test_fsdp2config_from_fsdp1_multiple_invalid_attributes(): """Test FSDP2Config.from_compatible_attrs with multiple invalid 
attributes issues warnings for each.""" mixed_attrs = { diff --git a/tests/trainer/test_fsdp2_gradscaler.py b/tests/trainer/test_fsdp2_gradscaler.py index a08ab43cc9..a1c8f50f3d 100644 --- a/tests/trainer/test_fsdp2_gradscaler.py +++ b/tests/trainer/test_fsdp2_gradscaler.py @@ -9,11 +9,11 @@ from torch.amp.grad_scaler import GradScaler from torch.distributed._tensor import DTensor +from composer.distributed.prepare_distributed import parallelize_model from composer.utils.parallelism import FSDP2Config from tests.common import ( world_size, ) -from tests.trainer.fsdp2_context import fsdp2_context, parallelize_model class SimpleModel(nn.Module): @@ -36,7 +36,6 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: @world_size(2) @pytest.mark.gpu -@fsdp2_context def test_fsdp2_with_gradscaler_inf(world_size: int): """Test FSDP2 with GradScaler can handle inf in grad.""" # Setup GradScaler for mixed precision training diff --git a/tests/trainer/test_fsdp2_helpers.py b/tests/trainer/test_fsdp2_helpers.py index f6d6abd8ad..04dd501123 100644 --- a/tests/trainer/test_fsdp2_helpers.py +++ b/tests/trainer/test_fsdp2_helpers.py @@ -7,8 +7,7 @@ import torch import torch.nn as nn -from tests.trainer.fsdp2_context import ( - fsdp2_context, +from composer.distributed.fsdp2_utils import ( get_standalone_and_tied_modules, legalize_param_sharing_between_modules, ) @@ -17,7 +16,6 @@ def _context(func: Callable) -> Optional[Callable]: """Decorator to run tests with models initialized on the meta device for torch version 2.6+.""" - @fsdp2_context def wrapper(*args, **kwargs): with torch.device('meta'): return func(*args, **kwargs) diff --git a/tests/trainer/test_fsdp2_wrap_fsdp_attribute.py b/tests/trainer/test_fsdp2_wrap_fsdp_attribute.py index 5909954786..86fede27f7 100644 --- a/tests/trainer/test_fsdp2_wrap_fsdp_attribute.py +++ b/tests/trainer/test_fsdp2_wrap_fsdp_attribute.py @@ -6,15 +6,16 @@ import torch.nn as nn from torch.distributed._tensor import DTensor -from composer.utils.parallelism import FSDP2Config -from tests.common import world_size -from tests.trainer.fsdp2_context import ( - _generate_default_policy, +from composer.distributed.fsdp2 import ( _recursive_apply_fully_shard, +) +from composer.distributed.fsdp2_utils import ( check_param_tying, - fsdp2_context, - parallelize_model, + generate_default_policy, ) +from composer.distributed.prepare_distributed import parallelize_model +from composer.utils.parallelism import FSDP2Config +from tests.common import world_size class NestedModule(nn.Module): @@ -60,7 +61,6 @@ def check_dtensors(params: list[torch.nn.Parameter]): # └── M7 (Linear) -@fsdp2_context @world_size(2) def test_fsdp_wrap_separate_modules(world_size: int): """Test FSDP wrapping applied to separate, non-overlapping modules (M2, M5).""" @@ -74,7 +74,6 @@ def test_fsdp_wrap_separate_modules(world_size: int): check_dtensors(list(m1.parameters())) -@fsdp2_context @world_size(2) def test_fsdp_wrap_error_tied_siblings(world_size: int): """Test error when siblings (M3, M4) with tied weights are both marked for FSDP wrapping.""" @@ -89,7 +88,6 @@ def test_fsdp_wrap_error_tied_siblings(world_size: int): parallelize_model(m1, fsdp2_config, opt) -@fsdp2_context @world_size(2) def test_fsdp_wrap_error_tied_sibling_one_wrapped(world_size: int): """Test error when one module (M3) marked for FSDP wrap shares weights with a sibling (M4).""" @@ -103,7 +101,6 @@ def test_fsdp_wrap_error_tied_sibling_one_wrapped(world_size: int): parallelize_model(m1, fsdp2_config, opt) -@fsdp2_context @world_size(2) 
def test_fsdp_wrap_ancestor_with_tied_children(world_size: int): """Test wrapping an ancestor (M2) whose children (M3, M4) have tied weights. Should succeed.""" @@ -119,7 +116,6 @@ def test_fsdp_wrap_ancestor_with_tied_children(world_size: int): assert id(m1.m2.m3.weight) == id(m1.m2.m4.weight), error_msg -@fsdp2_context @world_size(2) def test_fsdp_wrap_error_tied_across_branches_ancestor_wrap(world_size: int): """Test error when a wrapped module (M2) has a child (M3) tied to a module (M6) in another branch.""" @@ -133,7 +129,6 @@ def test_fsdp_wrap_error_tied_across_branches_ancestor_wrap(world_size: int): parallelize_model(m1, fsdp2_config, opt) -@fsdp2_context @world_size(2) def test_fsdp_wrap_root_with_tied_descendants(world_size: int): """Test wrapping the root (M1) when descendants (M3, M6) have tied weights. Should succeed.""" @@ -148,7 +143,6 @@ def test_fsdp_wrap_root_with_tied_descendants(world_size: int): assert id(m1.m2.m3.weight) == id(m1.m5.m6.weight), error_msg # type: ignore -@fsdp2_context @world_size(2) def test_fsdp_manual_policy_submodule_only(world_size: int): """Test manual policy application where only a submodule (M3) is wrapped.""" @@ -156,7 +150,7 @@ def test_fsdp_manual_policy_submodule_only(world_size: int): m1._fsdp_wrap = False # type: ignore # Ensure root is not wrapped by default policy m1.m2._fsdp_wrap = False # type: ignore # Ensure intermediate is not wrapped m1.m2.m3._fsdp_wrap = True # type: ignore # Target module for wrapping - auto_wrap_policy = _generate_default_policy(m1) + auto_wrap_policy = generate_default_policy(m1) target_modules_to_kwargs = auto_wrap_policy._run_policy(root_module=m1, ignored_modules=set(), root_kwargs={}) _recursive_apply_fully_shard(m1, m1, set(), target_modules_to_kwargs) # Check only m1.m2.m3 parameters are DTensors @@ -165,7 +159,6 @@ def test_fsdp_manual_policy_submodule_only(world_size: int): check_not_dtensors(other_params) -@fsdp2_context @world_size(2) def test_fsdp_wrap_parent_shares_with_child_parent_wrap(world_size: int): """Test wrapping a parent (M2) that shares weights with its child (M3). 
Should succeed.""" @@ -185,7 +178,6 @@ def test_fsdp_wrap_parent_shares_with_child_parent_wrap(world_size: int): assert id(m2_final.weight) == id(m3_final.weight), error_msg # type: ignore -@fsdp2_context @world_size(2) def test_fsdp_wrap_error_parent_child_share_both_wrap(world_size: int): """Test error when parent (M2) and child (M3) share weights and both are marked for wrapping.""" @@ -201,7 +193,6 @@ def test_fsdp_wrap_error_parent_child_share_both_wrap(world_size: int): parallelize_model(m1, fsdp2_config, opt) -@fsdp2_context @world_size(2) def test_fsdp_wrap_error_parent_child_share_child_wrap(world_size: int): """Test error when parent (M2) and child (M3) share weights and only the child is marked for wrapping.""" @@ -216,7 +207,6 @@ def test_fsdp_wrap_error_parent_child_share_child_wrap(world_size: int): parallelize_model(m1, fsdp2_config, opt) -@fsdp2_context @world_size(2) def test_fsdp_wrap_error_complex_sharing_parent_wrap(world_size: int): """Test error with complex sharing: parent (M2) shares with child (M3), child shares with other branch (M6), parent is wrapped.""" @@ -232,7 +222,6 @@ def test_fsdp_wrap_error_complex_sharing_parent_wrap(world_size: int): parallelize_model(m1, fsdp2_config, opt) -@fsdp2_context @world_size(2) def test_fsdp_wrap_error_tied_across_branches_one_wrap(world_size: int): """Test error (as noted in original comments) when weights are tied (M3, M6) but only one (M6) is marked for wrap.""" @@ -248,7 +237,6 @@ def test_fsdp_wrap_error_tied_across_branches_one_wrap(world_size: int): parallelize_model(m1, fsdp2_config, opt) -@fsdp2_context @world_size(2) def test_fsdp_wrap_fn_base(world_size: int): """Test that a custom wrap function can be provided to the FSDP2Config.""" @@ -265,7 +253,6 @@ def wrap_fn(module: nn.Module): check_dtensors(list(m1.parameters())) -@fsdp2_context @world_size(2) def test_fsdp_wrap_fn_invalid_keys(world_size: int): """Test that an error is raised if the wrap function returns a dict with invalid keys.""" @@ -281,7 +268,6 @@ def wrap_fn(module: nn.Module): parallelize_model(m1, fsdp2_config, opt) -@fsdp2_context @world_size(2) def test_fsdp_wrap_fn_target_module(world_size: int): """Test that if root module is not wrapped, the wrap function will wrap the target module.""" @@ -301,7 +287,6 @@ def wrap_fn(module: nn.Module): check_not_dtensors(other_params) -@fsdp2_context @world_size(2) def test_fsdp_wrap_fn_error_tied_siblings(world_size: int): """Test that if weights are tied and wrapped incorrectly, an error is raised.""" @@ -318,7 +303,6 @@ def wrap_fn(module: nn.Module): parallelize_model(m1, fsdp2_config, opt) -@fsdp2_context @world_size(2) def test_fsdp_wrap_fn_reshard_after_forward(world_size: int): """Test if the wrap function can correctly return a dictionary of kwargs.""" @@ -343,7 +327,6 @@ def check_reshard_after_forward(module: nn.Module): check_reshard_after_forward(module) -@fsdp2_context def test_check_param_tying(): """Test that if weights are tied different before and after the context, an error is raised.""" m1 = DeepNestedModel() @@ -357,7 +340,6 @@ def update_model(m1): update_model(m1) -@fsdp2_context @world_size(2) def test_check_param_tying_fsdp_wrap(world_size: int): """Test that if weights are tied different before and after the context, an error is raised.""" @@ -375,7 +357,6 @@ def update_model(m1): update_model(m1) -@fsdp2_context @world_size(2) def test_fsdp_wrap_attribute_warning(world_size: int): """Test that only one warning is raised if _fsdp_wrap is set in the model instead of fsdp_wrap_fn.""" 
@@ -393,7 +374,6 @@ def test_fsdp_wrap_attribute_warning(world_size: int): assert len(record) == 1, f'Expected exactly one warning but got {len(record)} warnings' -@fsdp2_context @world_size(2) def test_fsdp_wrap_attribute_not_raised_when_using_fsdp_wrap_fn(world_size: int): """Test that a warning is not raised when using fsdp_wrap_fn.""" diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index a940cb079c..21fbb015bc 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -315,10 +315,8 @@ def test_fsdp_full_state_dict_load( use_tp: bool, use_hsdp: bool, ): - if use_hsdp and version.parse(torch.__version__) < version.parse('2.4.0'): - pytest.xfail('HSDP requires torch 2.4.0 or later') if use_tp: - pytest.skip('TP on PyTorch 2.3 has full state dict issues.') + pytest.skip('TP has full state dict issues.') if autoresume: run_name = 'my-cool-autoresume-run' else: @@ -538,14 +536,6 @@ def test_fsdp_load_old_checkpoint( if composer_version == '0.18.1' and state_dict_type == 'full' and precision == 'amp_bf16' and sharding_strategy == 'FULL_SHARD': pytest.skip('TODO: This checkpoint is missing') - if (composer_version in ['0.22.0', '0.23.0'] and version.parse(torch.__version__) < version.parse('2.3.0') - ) or (composer_version in ['0.24.0', '0.25.0'] and - version.parse(torch.__version__) < version.parse('2.4.0')) or ( - composer_version in ['0.26.0', '0.27.0', '0.28.0'] and - version.parse(torch.__version__) < version.parse('2.5.0') - ) or (composer_version in ['0.29.0'] and version.parse(torch.__version__) < version.parse('2.6.0')): - pytest.skip('Current torch version is older than torch version that checkpoint was written with.') - if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']: if state_dict_type == 'sharded': pytest.skip('Loading legacy sharded checkpoints are not supported after v0.25.0.') @@ -848,12 +838,9 @@ def test_fsdp_partitioned_state_dict_load( request, ): if use_tp: - pytest.skip('TP on PyTorch 2.3 has sharded state dict issues.') + pytest.skip('TP has sharded state dict issues.') if weights_only and autoresume: pytest.skip('Weights only with autoresume is not supported') - if (use_tp or use_hsdp) and version.parse(torch.__version__) < version.parse('2.3.0'): - dist.barrier() # Sync to avoid race conditions on cleaning up tmp_path - pytest.skip('HSDP and TP require torch 2.3.0 or later') load_ignore_keys = [] if load_ignore_keys is None else load_ignore_keys @@ -1120,7 +1107,6 @@ def test_cleanup_sharded_checkpoints( @world_size(2) @pytest.mark.parametrize('weights_only', [False, True]) @pytest.mark.parametrize('planner', [None, 'rename']) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2.0 or higher') @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning') @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning') def test_fsdp_planner( @@ -1135,40 +1121,22 @@ def test_fsdp_planner( from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, DefaultSavePlanner from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata - if version.parse(torch.__version__) < version.parse('2.4.0'): - - class RenameSavePlanner(DefaultSavePlanner): # type: ignore - - def set_up_planner( - self, - state_dict: STATE_DICT_TYPE, - is_coordinator: bool = False, - ) -> None: - # suffix all keys with `foo_`` - state_dict['state']['model'] = {k + '_foo': v 
for k, v in state_dict['state']['model'].items()} + class RenameSavePlanner(DefaultSavePlanner): # type: ignore - super().set_up_planner( - state_dict=state_dict, - is_coordinator=is_coordinator, - ) - else: - - class RenameSavePlanner(DefaultSavePlanner): # type: ignore - - def set_up_planner( - self, - state_dict: STATE_DICT_TYPE, - storage_meta=None, - is_coordinator: bool = False, - ) -> None: - # suffix all keys with `foo_`` - state_dict['state']['model'] = {k + '_foo': v for k, v in state_dict['state']['model'].items()} + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + storage_meta=None, + is_coordinator: bool = False, + ) -> None: + # suffix all keys with `foo_`` + state_dict['state']['model'] = {k + '_foo': v for k, v in state_dict['state']['model'].items()} - super().set_up_planner( - state_dict=state_dict, - storage_meta=storage_meta, - is_coordinator=is_coordinator, - ) + super().set_up_planner( + state_dict=state_dict, + storage_meta=storage_meta, + is_coordinator=is_coordinator, + ) class RenameLoadPlanner(DefaultLoadPlanner): diff --git a/tests/trainer/test_fsdp_param_groups.py b/tests/trainer/test_fsdp_param_groups.py index 8db22c609f..1116c0c02a 100644 --- a/tests/trainer/test_fsdp_param_groups.py +++ b/tests/trainer/test_fsdp_param_groups.py @@ -5,7 +5,6 @@ import pytest import torch -from packaging import version from torch.utils.data import DataLoader from composer.optim import DecoupledSGDW @@ -55,10 +54,6 @@ def test_fsdp_param_groups_without_orig_params(mixed_precision: str, device: str @pytest.mark.filterwarnings('ignore::UserWarning') @device('gpu') @world_size(2) -@pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2'), - reason='FSDP use_orig_params requires torch 2.0 or higher', -) def test_fsdp_with_param_groups(mixed_precision: str, device: str, reentrant: bool, world_size: int): """ Test whether an optimizer with multiple param groups maintains the same param groups when @@ -128,10 +123,6 @@ def test_fsdp_with_param_groups(mixed_precision: str, device: str, reentrant: bo @pytest.mark.filterwarnings('ignore::UserWarning') @device('gpu') @world_size(2) -@pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2'), - reason='FSDP use_orig_params requires torch 2.0 or higher', -) def test_fsdp_with_param_groups_with_subset_of_params_in_opt( mixed_precision: str, device: str, diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py index 07f1ad1b37..89cd4fa43c 100644 --- a/tests/trainer/test_tp.py +++ b/tests/trainer/test_tp.py @@ -7,7 +7,6 @@ import numpy as np import pytest import torch -from packaging import version from torch.distributed._tensor import Replicate from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader, Dataset @@ -215,7 +214,6 @@ def get_stats(trainer: Trainer) -> dict[str, np.ndarray]: @pytest.mark.gpu @world_size(4) @pytest.mark.parametrize('replication', [2]) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+') @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') def test_tp_forwards_backwards_correctness(world_size: int, replication: int): """Test that training with DDP, FSDP, TP-FSDP results in the same: @@ -265,7 +263,6 @@ def test_tp_forwards_backwards_correctness(world_size: int, replication: int): @world_size(4) @pytest.mark.parametrize('replication', [2]) @pytest.mark.parametrize('batch_size', [1, 4]) 
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+') @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') def test_tp_fit_correctness(world_size: int, batch_size: int, replication: int): """Test that training with DDP, FSDP, TP-FSDP results in the same: @@ -313,7 +310,6 @@ def test_tp_fit_correctness(world_size: int, batch_size: int, replication: int): @pytest.mark.gpu @world_size(4) @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+') @pytest.mark.parametrize('tensor_parallel_degree', [1, 2]) def test_tp_train(world_size: int, tensor_parallel_degree: int): from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel @@ -357,7 +353,6 @@ def test_tp_train(world_size: int, tensor_parallel_degree: int): @pytest.mark.gpu @world_size(4) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+') @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') def test_tp_with_param_groups(world_size: int): from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel @@ -398,7 +393,6 @@ def test_tp_with_param_groups(world_size: int): @world_size(4) @pytest.mark.gpu -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+') @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') def test_tp_with_subset_of_params(world_size: int): from torch.distributed.tensor.parallel import ColwiseParallel diff --git a/tests/utils/test_inference.py b/tests/utils/test_inference.py index 69b78ead4c..d17ba691cd 100644 --- a/tests/utils/test_inference.py +++ b/tests/utils/test_inference.py @@ -12,7 +12,6 @@ import pytest import torch import torch.nn as nn -from packaging import version from torch.utils.data import DataLoader from composer.core import State @@ -112,9 +111,6 @@ def test_huggingface_export_for_inference_onnx(onnx_opset_version, tiny_bert_con pytest.importorskip('onnxruntime') pytest.importorskip('transformers') - if onnx_opset_version == None and version.parse(torch.__version__) < version.parse('1.13'): - pytest.skip("Don't test prior PyTorch version's default Opset version.") - import onnx import onnx.checker import onnxruntime as ort @@ -222,9 +218,6 @@ def test_export_for_inference_onnx(model_cls, sample_input, onnx_opset_version, pytest.importorskip('onnx') pytest.importorskip('onnxruntime') - if onnx_opset_version == None and version.parse(torch.__version__) < version.parse('1.13'): - pytest.skip("Don't test prior PyTorch version's default Opset version.") - pytest.xfail( 'torch.onnx.errors.UnsupportedOperatorError: Exporting the operator "aten::unflatten" to ONNX opset version 14 is not supported.', ) @@ -347,9 +340,6 @@ def test_export_for_inference_onnx_ddp(model_cls, sample_input, onnx_opset_versi 'torch.onnx.errors.UnsupportedOperatorError: Exporting the operator "aten::unflatten" to ONNX opset version 14 is not supported.', ) - if onnx_opset_version == None and version.parse(torch.__version__) < version.parse('1.13'): - pytest.skip("Don't test prior PyTorch version's default Opset version.") - import onnx import onnx.checker import onnxruntime as ort