6 changes: 1 addition & 5 deletions composer/cli/launcher.py
@@ -19,7 +19,6 @@

import psutil
import torch
from packaging import version

import composer
from composer.loggers.mosaicml_logger import (
@@ -321,10 +320,7 @@ def _launch_processes(
log.info('Distributed KV store: tcp://%s:%s', master_addr, master_port)

nccl_env_variable = {
(
'NCCL_ASYNC_ERROR_HANDLING' if version.parse(torch.__version__) < version.parse('2.2.0') else 'TORCH_NCCL_ASYNC_ERROR_HANDLING'
):
'1',
'TORCH_NCCL_ASYNC_ERROR_HANDLING': '1',
}

for local_rank in range(nproc):
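Note on the launcher.py hunk above: with support for older PyTorch dropped, the launcher now exports TORCH_NCCL_ASYNC_ERROR_HANDLING unconditionally instead of switching between the pre-2.2 and post-2.2 variable names. A minimal illustrative sketch of that pattern, not Composer's actual launcher (the worker script name and the environment variables beyond the NCCL one are assumptions):

import os
import subprocess
import sys

def launch_workers(nproc: int, master_addr: str, master_port: int) -> list[subprocess.Popen]:
    """Spawn one worker per local rank with the torch>=2.2 NCCL error-handling variable set."""
    procs = []
    for local_rank in range(nproc):
        env = os.environ.copy()
        env.update({
            'TORCH_NCCL_ASYNC_ERROR_HANDLING': '1',  # replaces NCCL_ASYNC_ERROR_HANDLING on torch >= 2.2
            'RANK': str(local_rank),
            'LOCAL_RANK': str(local_rank),
            'WORLD_SIZE': str(nproc),
            'MASTER_ADDR': master_addr,
            'MASTER_PORT': str(master_port),
        })
        procs.append(subprocess.Popen([sys.executable, 'train.py'], env=env))  # 'train.py' is a placeholder
    return procs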
269 changes: 93 additions & 176 deletions composer/core/state.py
@@ -11,13 +11,20 @@
from collections import OrderedDict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Sequence, Union, cast
from unittest.mock import MagicMock

import numpy as np
import torch
import torch.nn.modules.utils
from packaging import version
from torch.amp.grad_scaler import GradScaler
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
@@ -31,11 +38,6 @@
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric

if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.amp.grad_scaler import GradScaler # type: ignore
else:
from torch.cuda.amp.grad_scaler import GradScaler # type: ignore

from composer.core.data_spec import DataSpec
from composer.core.event import Event
from composer.core.precision import Precision
@@ -73,6 +75,10 @@
def fsdp_state_dict_type_context(module: torch.nn.Module, state_dict_type: str = 'full'):
"""Context manager for materializing or loading an fsdp module's state dict.

.. warning::
This function is deprecated and will be removed in a future release.
It is maintained for backwards compatibility with tests.

Args:
module (torch.nn.Module): The torch module that you want to call `state_dict()`
or `load_state_dict()` on.
@@ -87,6 +93,12 @@ def fsdp_state_dict_type_context(module: torch.nn.Module, state_dict_type: str =
"""
# Torch forgot to put ShardedStateDictConfig in torch/distributed/fsdp/__init__.py, so we
# have to import it this way.
warnings.warn(
'fsdp_state_dict_type_context is deprecated and will be removed in a future release',
DeprecationWarning,
stacklevel=2,
)

from torch.distributed.fsdp.fully_sharded_data_parallel import ShardedStateDictConfig

fsdp_state_dict_type = None
@@ -126,6 +138,10 @@ def fsdp_get_optim_state_dict(
) -> dict[str, Any]:
"""Materializes a given model's optimizer's state_dict.

.. warning::
This function is deprecated and will be removed in a future release.
It is maintained for backwards compatibility with tests.

Args:
model (torch.nn.Module): The model that the optimizer corresponds to.
optim (torch.optim.Optimizer): The optimizer that you want a state dict for.
@@ -140,6 +156,12 @@ def fsdp_get_optim_state_dict(
Returns:
dict[str, Any]: The state_dict for the given optimizer.
"""
warnings.warn(
'fsdp_get_optim_state_dict is deprecated and will be removed in a future release',
DeprecationWarning,
stacklevel=2,
)

with fsdp_state_dict_type_context(module=model, state_dict_type=state_dict_type):
return FSDP.optim_state_dict(model, optim) # type: ignore

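The two hunks above route the legacy FSDP helpers through warnings.warn with DeprecationWarning and stacklevel=2. A minimal sketch of that pattern (the helper name below is invented for illustration):

import warnings

def deprecated_helper() -> None:
    # stacklevel=2 attributes the warning to the caller rather than to this function body.
    warnings.warn(
        'deprecated_helper is deprecated and will be removed in a future release',
        DeprecationWarning,
        stacklevel=2,
    )

deprecated_helper()  # the warning is reported at this call site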
@@ -200,10 +222,6 @@ def _create_device_mesh(
fsdp_config: Optional[FSDPConfig | FSDP2Config],
tp_config: Optional[TPConfig],
) -> Optional[DeviceMesh]:
if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
# Device mesh has correctness issues before torch 2.3.0
return None

if fsdp_config is None:
return None

@@ -986,37 +1004,22 @@ def get_model_state_dict(self) -> dict[str, Any]:
Returns:
dict[str, Any]: The state dict for the model.
"""
if version.parse(torch.__version__) >= version.parse('2.4.0') or (
version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized()
):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
textwrap.dedent(
f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for '
f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set '
'fsdp_state_dict_type to None, "full", or "sharded".',
),
)

model_state_dict = get_model_state_dict(
model=self.model,
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
cpu_offload=self.fsdp_enabled,
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
textwrap.dedent(
f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for '
'torch version >= 2.4.0. Please set fsdp_state_dict_type to None, "full", or "sharded".',
),
)
else:
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
model_state_dict = self.model.state_dict()
else:
model_state_dict = self.model.state_dict()

# If model is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail
if self.is_model_ddp:
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.')
model_state_dict = get_model_state_dict(
model=self.model,
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
cpu_offload=self.fsdp_enabled,
),
)

return model_state_dict

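The rewritten get_model_state_dict above drops the per-version branches and relies only on torch.distributed.checkpoint.state_dict. A single-process sketch of that API on a plain module, assuming torch >= 2.4 (no FSDP wrapping or process group is needed for this toy case):

import torch
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

model = torch.nn.Linear(8, 8)
state = get_model_state_dict(
    model=model,
    options=StateDictOptions(
        full_state_dict=True,  # gather a regular, unsharded state dict
        cpu_offload=False,     # nothing sharded to offload here
    ),
)
print(sorted(state.keys()))  # ['bias', 'weight']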
@@ -1026,40 +1029,25 @@ def get_optim_state_dict(self) -> dict[str, Any]:
Returns:
dict[str, Any]: The state dict for the optimizer.
"""
if version.parse(torch.__version__) >= version.parse('2.4.0') or (
version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized()
):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
textwrap.dedent(
f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for '
f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set '
'fsdp_state_dict_type to None, "full", or "sharded".',
),
)

optimizer = ensure_tuple(self.optimizers)[0]
optim_state_dict = get_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
cpu_offload=self.fsdp_enabled,
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
textwrap.dedent(
f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for '
'torch version >= 2.4.0. Please set fsdp_state_dict_type to None, "full", or "sharded".',
),
)
return {type(optimizer).__qualname__: optim_state_dict}
else:
optimizer = ensure_tuple(self.optimizers)[0]
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
optim_state_dict = {
type(optimizer).__qualname__:
fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type),
}
else:
optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()}
return optim_state_dict

optimizer = ensure_tuple(self.optimizers)[0]
optim_state_dict = get_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
cpu_offload=self.fsdp_enabled,
),
)
return {type(optimizer).__qualname__: optim_state_dict}

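Likewise, the optimizer path above now always goes through get_optimizer_state_dict and set_optimizer_state_dict. A single-process round-trip sketch under the same assumptions (torch >= 2.4, no sharding):

import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Take one step so the optimizer has per-parameter state worth capturing.
model(torch.randn(2, 4)).sum().backward()
optimizer.step()

opts = StateDictOptions(full_state_dict=True, cpu_offload=False)
saved = get_optimizer_state_dict(model=model, optimizers=optimizer, options=opts)

# Restore into a fresh optimizer of the same type.
new_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
set_optimizer_state_dict(model=model, optimizers=new_optimizer, optim_state_dict=saved, options=opts)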
def state_dict(self) -> dict[str, Any]:
"""Collect the state dicts of our serializable attributes.
@@ -1338,67 +1326,31 @@ def load_model_state(
model_on_rank = state_dict['model'] is not None

if model_on_rank:
if version.parse(torch.__version__) >= version.parse('2.4.0') or (
version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized()
):
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
try:
set_model_state_dict(
model=self.model,
model_state_dict=state_dict['model'],
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
strict=strict,
cpu_offload=self.fsdp_enabled,
try:
set_model_state_dict(
model=self.model,
model_state_dict=state_dict['model'],
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
strict=strict,
cpu_offload=self.fsdp_enabled,
),
)
except AttributeError as e:
# Issue: https://github.com/pytorch/pytorch/issues/127351
if "ShardedTensor' object has no attribute 'placements'" in str(e):
raise RuntimeError(
textwrap.dedent(
'PyTorch DTensor broke backwards compatibility in older checkpoints '
'with ShardedTensor, which is now deprecated. To load old checkpoints, '
'either downgrade to PyTorch <2.3.0 or explicitly pass process groups '
'in the Trainer constructor via '
"`parallelism_config = {'fsdp': {'process_group': 'mod1'}}`. We can "
'provide assistance at https://github.com/mosaicml/composer/issues.',
),
)
except AttributeError as e:
# Issue: https://github.com/pytorch/pytorch/issues/127351
if "ShardedTensor' object has no attribute 'placements'" in str(e):
raise RuntimeError(
textwrap.dedent(
'PyTorch DTensor broke backwards compatibility in older checkpoints '
'with ShardedTensor, which is now deprecated. To load old checkpoints, '
'either downgrade to PyTorch <2.3.0 or explicitly pass process groups '
'in the Trainer constructor via '
"`parallelism_config = {'fsdp': {'process_group': 'mod1'}}`. We can "
'provide assistance at https://github.com/mosaicml/composer/issues.',
),
) from e
else:
raise e
else:
missing_keys, unexpected_keys = [], []
try:
# Load model if it exists
if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_monolith_rank0_only:
log.debug(
f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}',
)
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
missing_keys, unexpected_keys = self.model.load_state_dict(
state_dict['model'],
strict=strict,
)
else:
log.debug(f'Loading model state dict with strict={strict}')
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
except RuntimeError as e:
if 'Missing key(s) in state_dict' in str(e) or 'Unexpected key(s) in state_dict' in str(e):
raise RuntimeError(
textwrap.dedent(
'Failed to load checkpoint due to missing or unexpected keys in state_dict. '
'This is likely due to a change in the model architecture. If this is intentional, '
'you can set load_strict_model_weights=False in the Trainer.',
),
) from e
else:
raise e

if len(missing_keys) > 0:
log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
if len(unexpected_keys) > 0:
log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
) from e
else:
raise e

# If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading
if self.load_monolith_rank0_only:
@@ -1448,52 +1400,17 @@ def load_optim_state(self, state_dict: dict[str, Any], strict: bool = True):
continue

optim_state_dict = serialized_value[type(optimizer).__qualname__] if serialized_value is not None else None
if version.parse(torch.__version__) >= version.parse('2.4.0') or (
version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized()
):
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict

# optim_state_dict is `None` on non-zero ranks when loading FSDP monolith
# checkpoint on rank 0 only. However, PyTorch modifies the state_dict (producing
# errors) before discarding the output. Accordingly, we mock the state dict.
# See: https://github.com/pytorch/pytorch/issues/125177
if version.parse(torch.__version__) < version.parse('2.4.0'):
optim_state_dict = MagicMock() if optim_state_dict is None else optim_state_dict
set_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
optim_state_dict=optim_state_dict,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
strict=strict,
cpu_offload=self.fsdp_enabled,
),
)
else:
if self.fsdp_enabled:
assert self.fsdp_state_dict_type is not None # pyright
log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}')
# Loading FSDP monolith on rank 0 only requires FSDP.scatter_full_optim_state_dict
# as the context manager does not seem to pass rank0_only=True for the optimizer config
if self.load_monolith_rank0_only:
optim_state_dict = _legacy_optim_state_dict_to_load(
optim_state_dict=optim_state_dict,
model=self.model,
optim=optimizer,
state_dict_type=self.fsdp_state_dict_type,
)
else:
assert optim_state_dict is not None
with fsdp_state_dict_type_context(module=self.model, state_dict_type=self.fsdp_state_dict_type):
optim_state_dict = FSDP.optim_state_dict_to_load( # type: ignore
optim_state_dict=optim_state_dict, model=self.model, optim=optimizer,
)
assert optim_state_dict is not None
optimizer.load_state_dict(optim_state_dict)
else:
assert optim_state_dict is not None
log.debug(f'Loading optimizer state dict')
optimizer.load_state_dict(optim_state_dict)

set_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
optim_state_dict=optim_state_dict, # type: ignore
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
strict=strict,
cpu_offload=self.fsdp_enabled,
),
)

def load_state_dict(
self,
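The load_model_state error handler above tells users how to read old ShardedTensor checkpoints by passing an explicit FSDP process group. A configuration sketch of where that setting goes (the model, checkpoint path, and other Trainer arguments are placeholders, and a real run also needs a dataloader and a distributed launch):

import torch
from composer import Trainer
from composer.models import ComposerClassifier

model = ComposerClassifier(torch.nn.Linear(8, 2), num_classes=2)
trainer = Trainer(
    model=model,
    load_path='path/to/old_sharded_checkpoint',              # placeholder path
    parallelism_config={'fsdp': {'process_group': 'mod1'}},  # explicit process group, per the error message
)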
8 changes: 1 addition & 7 deletions composer/trainer/_scaler.py
@@ -5,15 +5,9 @@
from typing import Optional, Union

import torch
from packaging import version
from torch.cuda.amp.grad_scaler import GradScaler, OptState
from torch.amp.grad_scaler import GradScaler, OptState, _refresh_per_optimizer_state
from torch.optim import Optimizer

if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore
else:
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore

from composer.utils import dist

__all__ = ['ClosureGradScaler']
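_scaler.py now imports GradScaler, OptState, and _refresh_per_optimizer_state from torch.amp unconditionally, since torch >= 2.3 exposes them there rather than under torch.cuda.amp. A minimal single-GPU sketch of the torch.amp GradScaler loop (requires a CUDA device; the model and data are illustrative):

import torch
from torch.amp.grad_scaler import GradScaler

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler('cuda')

data = torch.randn(8, 16, device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    loss = model(data).sum()

scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)         # unscales gradients, skips the step if infs/NaNs appeared
scaler.update()                # adjusts the scale factor for the next iteration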