composer/checkpoint/load.py (16 changes: 6 additions, 10 deletions)
@@ -15,8 +15,8 @@
 
 import torch
 import torch.distributed.checkpoint as DCP
-from packaging import version
 from torch import nn
+from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim import Optimizer
 
@@ -273,7 +273,7 @@ def _load_sharded_model_checkpoint(
     model_state_dict = _cast_state_dict_to_precision(model_state_dict, load_options.precision)
     # TODO: raise warning for unknown or missing keys.
     log.debug(f'Loading sharded state dict into model.')
-    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
+    if dist.is_initialized():
         torch_set_model_state_dict(
             model,
             model_state_dict,
@@ -395,7 +395,7 @@ def _load_sharded_optim_checkpoint(
     )
     for param_key, param_state_dict in optim_state_dict['state'].items():
         optim_state_dict['state'][param_key] = _cast_state_dict_to_precision(param_state_dict, load_options.precision)
-    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
+    if dist.is_initialized():
         torch_set_optimizer_state_dict(
             model,
             optim,
@@ -522,7 +522,7 @@ def torch_set_model_state_dict(
         cpu_offload (bool): Whether to offload the state dict to CPU before setting.
         sharded_state_dict (bool): Whether the state dict is sharded or not.
     """
-    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
+    if dist.is_initialized():
         from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
         try:
             set_model_state_dict(
@@ -569,8 +569,7 @@ def torch_set_optimizer_state_dict(
         cpu_offload (bool): Whether to offload the state dict to CPU before setting.
         sharded_state_dict (bool): Whether the state dict is sharded or not.
     """
-    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
-        from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
+    if dist.is_initialized():
         set_optimizer_state_dict(
             model=model,
             optimizers=optim,
@@ -626,10 +625,7 @@ def download_and_load_sharded_state_dict(
         raise
 
     log.debug(f'Loading sharded state dict from {load_path} into memory.')
-    if version.parse(torch.__version__) < version.parse('2.2.0'):
-        DCP.load_state_dict(state_dict=state_dict, storage_reader=storage_reader, planner=load_planner)
-    else:
-        DCP.load(state_dict=state_dict, storage_reader=storage_reader, planner=load_planner)
+    DCP.load(state_dict=state_dict, storage_reader=storage_reader, planner=load_planner)
     return state_dict
 
 
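Note: with the version gates removed, load.py now calls torch's distributed-checkpoint setters whenever a process group is initialized. A minimal standalone sketch of that torch >= 2.3 API follows; the toy_model / toy_optim objects and the particular StateDictOptions values are illustrative placeholders, not part of this PR.

    import torch
    import torch.distributed as dist
    from torch import nn
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_optimizer_state_dict,
        set_model_state_dict,
        set_optimizer_state_dict,
    )

    toy_model = nn.Linear(4, 4)
    toy_optim = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)

    if dist.is_initialized():  # same guard the simplified code keeps
        # Sharded-load style options: keep per-rank shards, no CPU staging, tolerate missing keys.
        opts = StateDictOptions(full_state_dict=False, cpu_offload=False, strict=False)
        set_model_state_dict(toy_model, model_state_dict=toy_model.state_dict(), options=opts)
        # Round-trip the optimizer state through the matching getter/setter pair.
        optim_sd = get_optimizer_state_dict(toy_model, toy_optim, options=opts)
        set_optimizer_state_dict(toy_model, toy_optim, optim_state_dict=optim_sd, options=opts)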
composer/checkpoint/save.py (9 changes: 1 addition, 8 deletions)
@@ -15,7 +15,6 @@
 
 import torch
 import torch.distributed.checkpoint as DCP
-from packaging import version
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor
 
@@ -358,13 +357,7 @@ def _save_sharded_state_dict_to_disk(
         f'Starting saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
     )
 
-    # For 2.3.0 and above you can use checkpoint_id, but this version works the best for all versions
-    # of torch (and makes pyright happier) that we support, so we use it for now.
-    if version.parse(torch.__version__) < version.parse('2.2.0'):
-        DCP.save_state_dict(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
-    else:
-        DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
-
+    DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
     log.debug(
         f'Finished saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
     )
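For reference, the surviving save/load calls use the current DCP entry points, which replaced the deprecated DCP.save_state_dict / DCP.load_state_dict. A small sketch under assumed inputs (the /tmp/sharded_ckpt directory and toy_state dict are made up for illustration):

    import torch.distributed.checkpoint as DCP
    from torch import nn

    toy_state = {'model': nn.Linear(4, 4).state_dict()}

    # Write a (possibly sharded) checkpoint; runs single-process if torch.distributed is uninitialized.
    DCP.save(state_dict=toy_state, storage_writer=DCP.FileSystemWriter('/tmp/sharded_ckpt'))

    # Load back in place: DCP.load fills the tensors of an already-constructed state dict.
    DCP.load(state_dict=toy_state, storage_reader=DCP.FileSystemReader('/tmp/sharded_ckpt'))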
composer/checkpoint/state_dict.py (5 changes: 2 additions, 3 deletions)
@@ -16,7 +16,6 @@
     from composer.core.evaluator import Evaluator
 
 import torch
-from packaging import version
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel
@@ -65,7 +64,7 @@ def get_model_state_dict(
     cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict)
 
     log.debug('Extracting model state dict')
-    if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized():
+    if dist.is_initialized():
         from torch.distributed.checkpoint import state_dict as DCPSD  # Distributed Checkpoint State Dict
         from torch.distributed.checkpoint.state_dict import StateDictOptions
 
@@ -337,7 +336,7 @@ def get_optim_state_dict(
 
     cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict)
     log.debug('Extracting optim state dict')
-    if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized():
+    if dist.is_initialized():
         from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
         log.debug('Calling torch get_optimizer_state_dict...')
         optim_state_dict: dict[str, Any] = get_optimizer_state_dict(
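The simplified get_* paths above lean on torch's StateDictOptions to express the sharded-vs-full and CPU-offload choices. A hedged sketch of that idea with an assumed helper name (extract_state_dicts) and an assumed full_state_dict/cpu_offload mapping, not copied from this PR:

    import torch
    import torch.distributed as dist
    from torch import nn
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
    )

    def extract_state_dicts(
        model: nn.Module,
        optim: torch.optim.Optimizer,
        sharded_state_dict: bool,
        cpu_offload: bool,
    ):
        if not dist.is_initialized():
            # Single-process fallback, analogous to the non-distributed branch.
            return model.state_dict(), optim.state_dict()
        opts = StateDictOptions(full_state_dict=not sharded_state_dict, cpu_offload=cpu_offload)
        return (
            get_model_state_dict(model, options=opts),
            get_optimizer_state_dict(model, optim, options=opts),
        )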
composer/core/state.py (11 changes: 2 additions, 9 deletions)
@@ -15,7 +15,6 @@
 import numpy as np
 import torch
 import torch.nn.modules.utils
-from packaging import version
 from torch.amp.grad_scaler import GradScaler
 from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.checkpoint.state_dict import (
@@ -25,6 +24,7 @@
     set_model_state_dict,
     set_optimizer_state_dict,
 )
+from torch.distributed.fsdp import FSDPModule
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullOptimStateDictConfig,
@@ -601,8 +601,6 @@ def _validate_parallelism_configs(self):
         # Validate TP config
         if self.tp_config is not None:
             warnings.warn('Tensor parallelism (TP) is experimental and may change in future versions.', FutureWarning)
-            if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
-                raise ValueError('Tensor parallelism (TP) requires torch>=2.3.0.')
             if self.fsdp_config is None:
                 raise ValueError(
                     'Tensor parallelism (TP) currently requires FSDP to be enabled. '
@@ -930,13 +928,8 @@ def fsdp_enabled(self):
         """Indicates if FSDP is enabled."""
         for module in self.model.modules():
             # FSDP is FSDP1, FSDPModule is FSDP2
-            if isinstance(module, FSDP):
+            if isinstance(module, (FSDP, FSDPModule)):
                 return True
-            # TODO remove this once we deprecate torch 2.5
-            if version.parse(torch.__version__) >= version.parse('2.6.0'):
-                from torch.distributed.fsdp._fully_shard import FSDPModule
-                if isinstance(module, FSDPModule):
-                    return True
         return False
 
     @property
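The fsdp_enabled change collapses FSDP1 and FSDP2 detection into one isinstance check, using the FSDPModule import added at the top of the file. A minimal sketch of the same idea on a bare module tree (the helper name fsdp_enabled_on is hypothetical):

    from torch import nn
    from torch.distributed.fsdp import FSDPModule  # FSDP2 (fully_shard) marker class
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP  # FSDP1 wrapper

    def fsdp_enabled_on(model: nn.Module) -> bool:
        # FSDP1 inserts FullyShardedDataParallel wrapper modules; FSDP2's
        # fully_shard() turns the sharded modules themselves into FSDPModule instances.
        return any(isinstance(m, (FSDP, FSDPModule)) for m in model.modules())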
composer/devices/device_mps.py (3 changes: 0 additions, 3 deletions)
@@ -10,7 +10,6 @@
 import torch
 import torch.cuda.amp
 import torch.utils.data
-from packaging import version
 
 from composer.devices.device import Device
 
@@ -28,8 +27,6 @@ class DeviceMPS(Device):
     name = 'mps'
 
     def __init__(self) -> None:
-        if version.parse(torch.__version__) < version.parse('1.12.0'):
-            raise RuntimeError('Support for MPS device requires torch >= 1.12.')
         if not torch.backends.mps.is_available():  # type: ignore (version guarded)
             raise RuntimeError('MPS requires MAC OSX >= 12.3')
         if not torch.backends.mps.is_built():  # type: ignore (version guarded)
composer/distributed/__init__.py (2 changes: 2 additions, 0 deletions)
@@ -11,11 +11,13 @@
     prepare_fsdp_module,
     prepare_tp_module,
 )
+from composer.distributed.prepare_distributed import parallelize_composer_model
 
 __all__ = [
     'DDPSyncStrategy',
     'ddp_sync_context',
     'prepare_ddp_module',
     'prepare_fsdp_module',
     'prepare_tp_module',
+    'parallelize_composer_model',
 ]
composer/distributed/dist_strategy.py (5 changes: 1 addition, 4 deletions)
@@ -9,7 +9,6 @@
 from typing import Any, Callable, ContextManager, Iterator, Optional, Sequence, Union, cast
 
 import torch
-from packaging import version
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     CheckpointImpl,
     apply_activation_checkpointing,
@@ -338,9 +337,7 @@ def sync_hook(*args):
     sharding_strategy = SHARDING_MAP[sharding_map_key]
 
     kwargs = {}
-    if version.parse(
-        torch.__version__.split('.dev')[0],
-    ) >= version.parse('2.2.0') and fsdp_config.device_mesh is not None:
+    if fsdp_config.device_mesh is not None:
         if fsdp_config.process_group is not None:
             warnings.warn(
                 'process_group and device_mesh are set for FSDP, so ignoring device_mesh. Please set process_group to None.',
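With the torch 2.2 gate gone, the branch above keys only on whether a device mesh was configured. A hedged sketch of handing a mesh to FSDP1 via its device_mesh kwarg (the mesh shape and toy model are illustrative; Composer builds its real mesh from the parallelism config elsewhere):

    import torch.distributed as dist
    from torch import nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    if dist.is_initialized():
        mesh = init_device_mesh('cuda', (dist.get_world_size(),))
        kwargs = {'device_mesh': mesh}  # used in place of an explicit process_group
        fsdp_model = FSDP(nn.Linear(8, 8).cuda(), **kwargs)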