6 changes: 1 addition & 5 deletions composer/cli/launcher.py
@@ -19,7 +19,6 @@

import psutil
import torch
from packaging import version

import composer
from composer.loggers.mosaicml_logger import (
@@ -321,10 +320,7 @@ def _launch_processes(
log.info('Distributed KV store: tcp://%s:%s', master_addr, master_port)

nccl_env_variable = {
(
'NCCL_ASYNC_ERROR_HANDLING' if version.parse(torch.__version__) < version.parse('2.2.0') else 'TORCH_NCCL_ASYNC_ERROR_HANDLING'
):
'1',
'TORCH_NCCL_ASYNC_ERROR_HANDLING': '1',
}

for local_rank in range(nproc):
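Note on the launcher hunk above: with the packaging-based version check gone, the launcher unconditionally sets the renamed `TORCH_NCCL_ASYNC_ERROR_HANDLING` variable (the old `NCCL_ASYNC_ERROR_HANDLING` name was only needed for torch < 2.2.0 per the removed branch). A minimal sketch of the resulting behavior; the `os` plumbing here is illustrative, not the launcher's actual code:

```python
import os

# All supported torch versions use the renamed variable, so no version gate is needed.
nccl_env_variable = {
    'TORCH_NCCL_ASYNC_ERROR_HANDLING': '1',
}

# Illustrative only: each worker process would be launched with this merged environment.
worker_env = {**os.environ, **nccl_env_variable}
```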
259 changes: 83 additions & 176 deletions composer/core/state.py
@@ -11,13 +11,20 @@
from collections import OrderedDict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Sequence, Union, cast
from unittest.mock import MagicMock

import numpy as np
import torch
import torch.nn.modules.utils
from packaging import version
from torch.amp.grad_scaler import GradScaler
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
@@ -31,11 +38,6 @@
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric

if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.amp.grad_scaler import GradScaler # type: ignore
else:
from torch.cuda.amp.grad_scaler import GradScaler # type: ignore

from composer.core.data_spec import DataSpec
from composer.core.event import Event
from composer.core.precision import Precision
@@ -126,6 +128,10 @@ def fsdp_get_optim_state_dict(
) -> dict[str, Any]:
"""Materializes a given model's optimizer's state_dict.

.. warning::
This function is deprecated and will be removed in Composer version 0.32.
It is maintained for backwards compatibility with tests.

Args:
model (torch.nn.Module): The model that the optimizer corresponds to.
optim (torch.optim.Optimizer): The optimizer that you want a state dict for.
@@ -140,6 +146,12 @@
Returns:
dict[str, Any]: The state_dict for the given optimizer.
"""
warnings.warn(
'fsdp_get_optim_state_dict is deprecated and will be removed in Composer version 0.32',
DeprecationWarning,
stacklevel=2,
)

with fsdp_state_dict_type_context(module=model, state_dict_type=state_dict_type):
return FSDP.optim_state_dict(model, optim) # type: ignore
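Since `fsdp_get_optim_state_dict` is now only kept for backwards compatibility with tests, callers can assert on the new warning. A hedged sketch using `pytest` (the `model`/`optimizer` fixtures are assumptions, not part of this diff):

```python
import pytest

from composer.core.state import fsdp_get_optim_state_dict


def test_fsdp_get_optim_state_dict_warns(model, optimizer):  # hypothetical fixtures
    # The shim still returns an optimizer state dict, but now emits a
    # DeprecationWarning announcing removal in Composer 0.32.
    with pytest.warns(DeprecationWarning, match='fsdp_get_optim_state_dict is deprecated'):
        state = fsdp_get_optim_state_dict(model, optimizer, state_dict_type='full')
    assert isinstance(state, dict)
```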

@@ -200,10 +212,6 @@ def _create_device_mesh(
fsdp_config: Optional[FSDPConfig | FSDP2Config],
tp_config: Optional[TPConfig],
) -> Optional[DeviceMesh]:
if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
# Device mesh has correctness issues before torch 2.3.0
return None

if fsdp_config is None:
return None
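With the torch < 2.3.0 correctness guard removed, `_create_device_mesh` can always build a `DeviceMesh` via `init_device_mesh`. For orientation, a minimal sketch of that torch API under a distributed launch (dimension names and sizes are illustrative assumptions, not Composer's actual mesh layout):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh


def build_example_mesh(replicate_degree: int, shard_degree: int):
    # Requires a distributed run (e.g. torchrun); builds a 2D HSDP-style mesh where
    # one dimension replicates and the other shards model parameters.
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    return init_device_mesh(
        device_type,
        (replicate_degree, shard_degree),
        mesh_dim_names=('replicate', 'shard'),
    )
```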

@@ -986,37 +994,22 @@ def get_model_state_dict(self) -> dict[str, Any]:
Returns:
dict[str, Any]: The state dict for the model.
"""
if version.parse(torch.__version__) >= version.parse('2.4.0') or (
version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized()
):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
textwrap.dedent(
f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for '
f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set '
'fsdp_state_dict_type to None, "full", or "sharded".',
),
)

model_state_dict = get_model_state_dict(
model=self.model,
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
cpu_offload=self.fsdp_enabled,
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
textwrap.dedent(
f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for '
'torch version >= 2.4.0. Please set fsdp_state_dict_type to None, "full", or "sharded".',
),
)
else:
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
model_state_dict = self.model.state_dict()
else:
model_state_dict = self.model.state_dict()

# If model is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail
if self.is_model_ddp:
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.')
model_state_dict = get_model_state_dict(
model=self.model,
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
cpu_offload=self.fsdp_enabled,
),
)

return model_state_dict
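The model path now always goes through `torch.distributed.checkpoint.state_dict` instead of branching on torch version. A hedged standalone sketch of the same call pattern for the `'full'` case, outside of `State`:

```python
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict


def full_model_state_dict(model):
    # Gather a full (unsharded) state dict offloaded to CPU, mirroring the options
    # State uses when fsdp_state_dict_type == 'full' on an FSDP-wrapped model.
    return get_model_state_dict(
        model=model,
        submodules=None,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True),
    )
```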

@@ -1026,40 +1019,25 @@ def get_optim_state_dict(self) -> dict[str, Any]:
Returns:
dict[str, Any]: The state dict for the optimizer.
"""
if version.parse(torch.__version__) >= version.parse('2.4.0') or (
version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized()
):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
textwrap.dedent(
f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for '
f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set '
'fsdp_state_dict_type to None, "full", or "sharded".',
),
)

optimizer = ensure_tuple(self.optimizers)[0]
optim_state_dict = get_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
cpu_offload=self.fsdp_enabled,
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
textwrap.dedent(
f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for '
'torch version >= 2.4.0. Please set fsdp_state_dict_type to None, "full", or "sharded".',
),
)
return {type(optimizer).__qualname__: optim_state_dict}
else:
optimizer = ensure_tuple(self.optimizers)[0]
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
optim_state_dict = {
type(optimizer).__qualname__:
fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type),
}
else:
optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()}
return optim_state_dict

optimizer = ensure_tuple(self.optimizers)[0]
optim_state_dict = get_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
cpu_offload=self.fsdp_enabled,
),
)
return {type(optimizer).__qualname__: optim_state_dict}
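The optimizer path mirrors the model path; Composer stores the result keyed by the optimizer's qualified name. A hedged sketch of that wrapping convention:

```python
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict


def optim_state_for_checkpoint(model, optimizer, full_state_dict: bool, cpu_offload: bool):
    # Produces the nested mapping Composer checkpoints use, e.g. {'AdamW': {...}}
    # for a torch.optim.AdamW instance.
    optim_state_dict = get_optimizer_state_dict(
        model=model,
        optimizers=optimizer,
        submodules=None,
        options=StateDictOptions(full_state_dict=full_state_dict, cpu_offload=cpu_offload),
    )
    return {type(optimizer).__qualname__: optim_state_dict}
```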

def state_dict(self) -> dict[str, Any]:
"""Collect the state dicts of our serializable attributes.
@@ -1338,67 +1316,31 @@ def load_model_state(
model_on_rank = state_dict['model'] is not None

if model_on_rank:
if version.parse(torch.__version__) >= version.parse('2.4.0') or (
version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized()
):
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
try:
set_model_state_dict(
model=self.model,
model_state_dict=state_dict['model'],
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
strict=strict,
cpu_offload=self.fsdp_enabled,
try:
set_model_state_dict(
model=self.model,
model_state_dict=state_dict['model'],
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
strict=strict,
cpu_offload=self.fsdp_enabled,
),
)
except AttributeError as e:
# Issue: https://github.com/pytorch/pytorch/issues/127351
if "ShardedTensor' object has no attribute 'placements'" in str(e):
raise RuntimeError(
textwrap.dedent(
'PyTorch DTensor broke backwards compatibility in older checkpoints '
'with ShardedTensor, which is now deprecated. To load old checkpoints, '
'either downgrade to PyTorch <2.3.0 or explicitly pass process groups '
'in the Trainer constructor via '
"`parallelism_config = {'fsdp': {'process_group': 'mod1'}}`. We can "
'provide assistance at https://github.com/mosaicml/composer/issues.',
),
)
except AttributeError as e:
# Issue: https://github.com/pytorch/pytorch/issues/127351
if "ShardedTensor' object has no attribute 'placements'" in str(e):
raise RuntimeError(
textwrap.dedent(
'PyTorch DTensor broke backwards compatibility in older checkpoints '
'with ShardedTensor, which is now deprecated. To load old checkpoints, '
'either downgrade to PyTorch <2.3.0 or explicitly pass process groups '
'in the Trainer constructor via '
"`parallelism_config = {'fsdp': {'process_group': 'mod1'}}`. We can "
'provide assistance at https://github.com/mosaicml/composer/issues.',
),
) from e
else:
raise e
else:
missing_keys, unexpected_keys = [], []
try:
# Load model if it exists
if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_monolith_rank0_only:
log.debug(
f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}',
)
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
missing_keys, unexpected_keys = self.model.load_state_dict(
state_dict['model'],
strict=strict,
)
else:
log.debug(f'Loading model state dict with strict={strict}')
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
except RuntimeError as e:
if 'Missing key(s) in state_dict' in str(e) or 'Unexpected key(s) in state_dict' in str(e):
raise RuntimeError(
textwrap.dedent(
'Failed to load checkpoint due to missing or unexpected keys in state_dict. '
'This is likely due to a change in the model architecture. If this is intentional, '
'you can set load_strict_model_weights=False in the Trainer.',
),
) from e
else:
raise e

if len(missing_keys) > 0:
log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
if len(unexpected_keys) > 0:
log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
) from e
else:
raise e
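Loading now mirrors saving via `set_model_state_dict`; the remaining special case is translating the known `ShardedTensor` `AttributeError` (pytorch#127351) into an actionable message. A hedged sketch of that error-translation pattern in isolation (helper name and message are illustrative):

```python
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict


def load_full_checkpoint(model, model_state_dict, strict: bool = True):
    try:
        set_model_state_dict(
            model=model,
            model_state_dict=model_state_dict,
            options=StateDictOptions(full_state_dict=True, strict=strict, cpu_offload=True),
        )
    except AttributeError as e:
        # Old ShardedTensor-based checkpoints are incompatible with DTensor loading.
        if "ShardedTensor' object has no attribute 'placements'" in str(e):
            raise RuntimeError('Legacy ShardedTensor checkpoint; see pytorch/pytorch#127351.') from e
        raise
```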

# If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading
if self.load_monolith_rank0_only:
@@ -1448,52 +1390,17 @@ def load_optim_state(self, state_dict: dict[str, Any], strict: bool = True):
continue

optim_state_dict = serialized_value[type(optimizer).__qualname__] if serialized_value is not None else None
if version.parse(torch.__version__) >= version.parse('2.4.0') or (
version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized()
):
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict

# optim_state_dict is `None` on non-zero ranks when loading FSDP monolith
# checkpoint on rank 0 only. However, PyTorch modifies the state_dict (producing
# errors) before discarding the output. Accordingly, we mock the state dict.
# See: https://github.com/pytorch/pytorch/issues/125177
if version.parse(torch.__version__) < version.parse('2.4.0'):
optim_state_dict = MagicMock() if optim_state_dict is None else optim_state_dict
set_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
optim_state_dict=optim_state_dict,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
strict=strict,
cpu_offload=self.fsdp_enabled,
),
)
else:
if self.fsdp_enabled:
assert self.fsdp_state_dict_type is not None # pyright
log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}')
# Loading FSDP monolith on rank 0 only requires FSDP.scatter_full_optim_state_dict
# as the context manager does not seem to pass rank0_only=True for the optimizer config
if self.load_monolith_rank0_only:
optim_state_dict = _legacy_optim_state_dict_to_load(
optim_state_dict=optim_state_dict,
model=self.model,
optim=optimizer,
state_dict_type=self.fsdp_state_dict_type,
)
else:
assert optim_state_dict is not None
with fsdp_state_dict_type_context(module=self.model, state_dict_type=self.fsdp_state_dict_type):
optim_state_dict = FSDP.optim_state_dict_to_load( # type: ignore
optim_state_dict=optim_state_dict, model=self.model, optim=optimizer,
)
assert optim_state_dict is not None
optimizer.load_state_dict(optim_state_dict)
else:
assert optim_state_dict is not None
log.debug(f'Loading optimizer state dict')
optimizer.load_state_dict(optim_state_dict)

set_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
optim_state_dict=optim_state_dict, # type: ignore
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
strict=strict,
cpu_offload=self.fsdp_enabled,
),
)
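With torch >= 2.4 the `MagicMock` workaround for pytorch#125177 is no longer needed, so `set_optimizer_state_dict` receives the rank-local value directly (hence the `# type: ignore` for the monolith rank-0-only case, where it is `None` off rank 0). A hedged sketch of the loading call, assuming a checkpoint layout of `{'AdamW': {...}}`:

```python
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict


def load_optimizer_checkpoint(model, optimizer, serialized, strict: bool = True):
    # serialized is the per-optimizer mapping from the checkpoint; it may be None on
    # non-zero ranks when a monolith checkpoint is loaded on rank 0 only.
    optim_state_dict = serialized.get(type(optimizer).__qualname__) if serialized is not None else None
    set_optimizer_state_dict(
        model=model,
        optimizers=optimizer,
        optim_state_dict=optim_state_dict,  # type: ignore  # per the diff, None is passed through on newer torch
        options=StateDictOptions(full_state_dict=True, strict=strict, cpu_offload=True),
    )
```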

def load_state_dict(
self,
8 changes: 1 addition & 7 deletions composer/trainer/_scaler.py
@@ -5,15 +5,9 @@
from typing import Optional, Union

import torch
from packaging import version
from torch.cuda.amp.grad_scaler import GradScaler, OptState
from torch.amp.grad_scaler import GradScaler, OptState, _refresh_per_optimizer_state
from torch.optim import Optimizer

if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore
else:
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore

from composer.utils import dist

__all__ = ['ClosureGradScaler']
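`ClosureGradScaler` now extends the single `torch.amp` `GradScaler` import. For orientation only, a minimal mixed-precision loop using the public API it builds on (`model`/`optimizer`/`loss_fn`/`loader` are placeholders, not Composer objects):

```python
import torch
from torch.amp.grad_scaler import GradScaler


def train_amp(model, optimizer, loss_fn, loader, device: str = 'cuda'):
    scaler = GradScaler(device)  # persists across steps so the loss scale can adapt
    for batch, target in loader:
        optimizer.zero_grad()
        with torch.autocast(device_type=device, dtype=torch.float16):
            loss = loss_fn(model(batch), target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
```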