
Commit 03ac462

[FSDP2] Init FSDP2 based checkpointing (#3824)
Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: Daniel King <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

Kick-starts support for FSDP2-based checkpointing in Trainer and State:

- State.fsdp_config exposes the same interface for FSDP2Config and FSDPConfig, so either config can be passed to Trainer/State.
- FSDP2Config currently implements only a minimal default set of interface methods/properties needed to keep State functional.
- State supports FSDP2-based checkpointing.
- Checkpoints saved from FSDP1 artifacts can be loaded into FSDP2.
1 parent 44345ae commit 03ac462
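
The sketch below illustrates the unified config interface this commit introduces. It assumes a toy `model` and `train_dataloader` exist and that the Trainer is invoked with its existing `parallelism_config` argument; note that FSDP2 module wrapping itself is still marked TODO in this commit, so this is an interface illustration, not a guaranteed end-to-end FSDP2 training path.

```python
# Sketch only: how the unified FSDP1/FSDP2 config interface could be exercised.
# `model` and `train_dataloader` are hypothetical placeholders.
from composer import Trainer
from composer.utils import FSDP2Config, FSDPConfig, ParallelismConfig

# FSDP1 path (unchanged): pass an FSDPConfig via ParallelismConfig.fsdp.
fsdp1_parallelism = ParallelismConfig(fsdp=FSDPConfig())

# FSDP2 path (new): pass an FSDP2Config via ParallelismConfig.fsdp2;
# State.fsdp_config resolves to whichever one is set.
fsdp2_parallelism = ParallelismConfig(fsdp2=FSDP2Config(reshard_after_forward=True))

trainer = Trainer(
    model=model,                        # hypothetical ComposerModel
    train_dataloader=train_dataloader,  # hypothetical dataloader
    parallelism_config=fsdp2_parallelism,
    max_duration='1ep',
)
```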

9 files changed: +427 −89 lines changed


composer/core/state.py

Lines changed: 30 additions & 2 deletions
```diff
@@ -43,6 +43,7 @@
 from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
 from composer.devices import Device
 from composer.utils import (
+    FSDP2Config,
     FSDPConfig,
     ParallelismConfig,
     ParallelismType,
@@ -196,7 +197,7 @@ def _ensure_backwards_compatible_checkpointing(state_dict: dict[str, Any]):

 def _create_device_mesh(
     device: Device,
-    fsdp_config: Optional[FSDPConfig],
+    fsdp_config: Optional[FSDPConfig | FSDP2Config],
     tp_config: Optional[TPConfig],
 ) -> Optional[DeviceMesh]:
     if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
@@ -536,7 +537,8 @@ def __init__(

         self.profiler: Optional[Profiler] = None

-        self.fsdp_config = parallelism_config.fsdp if parallelism_config is not None else None
+        self._fsdp_config = parallelism_config.fsdp if parallelism_config is not None else None
+        self._fsdp2_config = parallelism_config.fsdp2 if parallelism_config is not None else None
         self.tp_config = parallelism_config.tp if parallelism_config is not None else None

         self.automicrobatch_fsdp_hook_handles = []
@@ -873,6 +875,27 @@ def evaluators(self):
     def evaluators(self, evaluators: Union[Evaluator, Sequence[Evaluator]]):
         self._evaluators[:] = list(ensure_tuple(evaluators))

+    @property
+    def fsdp_config(self):
+        """Returns the appropriate FSDP configuration to use.
+
+        Prioritizes FSDP2 config if available, otherwise falls back to FSDP1 config.
+        """
+        return self._fsdp2_config if self._fsdp2_config is not None else self._fsdp_config
+
+    # For backward compatibility
+    @fsdp_config.setter
+    def fsdp_config(self, value: FSDPConfig | FSDP2Config):
+        """Sets the FSDP configuration, handling both FSDP1 and FSDP2 configurations."""
+        if isinstance(value, FSDPConfig):
+            self._fsdp_config = value
+            self._fsdp2_config = None
+        elif isinstance(value, FSDP2Config):
+            self._fsdp2_config = value
+            self._fsdp_config = None
+        else:
+            raise TypeError(f'Expected value to be of type FSDPConfig or FSDP2Config, but got {type(value)}.')
+
     @property
     def fsdp_enabled(self):
         """Indicates if FSDP is enabled."""
@@ -1384,6 +1407,11 @@ def load_model_state(
            with reproducibility.seed_context(self.rank_zero_seed):
                from composer.distributed import prepare_fsdp_module

+               # TODO (FSDP2): support calling FSDP2 wrapper depending on the config type
+               assert isinstance(
+                   self.fsdp_config,
+                   FSDPConfig,
+               ), f'prepare_fsdp_module requires FSDPConfig, got: {type(self.fsdp_config)}'
                self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module(
                    self.model,
                    self.optimizers,
```
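
To make the resolution rules above concrete, here is a standalone mirror of the new `State.fsdp_config` property/setter logic, without constructing a full `State` object (which needs a model, seed, and device). `_DemoState` is a hypothetical stand-in, not Composer's class.

```python
# Mirrors the getter/setter logic from the diff above for illustration.
from typing import Optional, Union

from composer.utils import FSDP2Config, FSDPConfig


class _DemoState:
    def __init__(self, fsdp: Optional[FSDPConfig] = None, fsdp2: Optional[FSDP2Config] = None):
        self._fsdp_config = fsdp
        self._fsdp2_config = fsdp2

    @property
    def fsdp_config(self) -> Optional[Union[FSDPConfig, FSDP2Config]]:
        # FSDP2 config wins if present; otherwise fall back to FSDP1.
        return self._fsdp2_config if self._fsdp2_config is not None else self._fsdp_config

    @fsdp_config.setter
    def fsdp_config(self, value: Union[FSDPConfig, FSDP2Config]):
        # Setting one flavor clears the other, so only a single config is active.
        if isinstance(value, FSDPConfig):
            self._fsdp_config, self._fsdp2_config = value, None
        elif isinstance(value, FSDP2Config):
            self._fsdp2_config, self._fsdp_config = value, None
        else:
            raise TypeError(f'Expected FSDPConfig or FSDP2Config, got {type(value)}.')


state = _DemoState(fsdp=FSDPConfig())
assert isinstance(state.fsdp_config, FSDPConfig)
state.fsdp_config = FSDP2Config()      # switches the active flavor
assert isinstance(state.fsdp_config, FSDP2Config)
assert state._fsdp_config is None      # the FSDP1 slot was cleared
```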

composer/distributed/fsdp2.py

Lines changed: 7 additions & 38 deletions
```diff
@@ -3,39 +3,10 @@


 """Helpers for FSDP2."""

-import warnings
-from dataclasses import dataclass
-from typing import Optional, Union
-
-from torch import nn
-from torch.distributed._tensor.device_mesh import DeviceMesh
+import torch.nn as nn
 from torch.distributed.fsdp._fully_shard import fully_shard
-from torch.distributed.fsdp._fully_shard._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
-
-
-@dataclass
-class FSDP2Config:
-    """Configuration for Fully Sharded Data Parallelism (FSDP2).
-
-    Args:
-        device_mesh (Optional[DeviceMesh]): The DeviceMesh for sharding. If None, a default 1D mesh is created.
-            For 1D mesh, parameters are fully sharded across the mesh (FSDP).
-            For 2D mesh, parameters are sharded across the 1st dimension and replicated across the 0th dimension (HSDP).
-        reshard_after_forward (Union[bool, int]): Controls parameter behavior after forward:
-            - If True, reshards parameters after forward, re-all-gathers in backward.
-            - If False, keeps unsharded parameters in memory, avoids all-gather in backward.
-            - If int, reshards to smaller world size after forward.
-            Default: True
-        mp_policy (Optional[MixedPrecisionPolicy]): Mixed precision policy. Default: None
-        offload_policy (Optional[OffloadPolicy]): Offloading policy. Default: None
-    """
-    device_mesh: Optional[DeviceMesh] = None
-    reshard_after_forward: Union[bool, int] = True
-    mp_policy: Optional[MixedPrecisionPolicy] = None
-    offload_policy: Optional[OffloadPolicy] = None
-
-    def __post_init__(self):
-        warnings.warn('FSDP2 Config/APIs are experimental and subject to heavy changes', UserWarning)
+from composer.utils.parallelism import FSDP2Config


 def get_standalone_and_tied_modules(modules: list[nn.Module]) -> tuple[list[nn.Module], set[nn.Module]]:
@@ -127,6 +98,8 @@ def apply_fully_shard(
 ) -> None:
     """Applies FSDP2's `fully_shard` to the specified modules and then to the parent model.

+    NOTE: FSDP is only applied to nn.Parameters, not Buffers.
+
     Args:
         model (torch.nn.Module): The parent model.
         independent_submodules (list[torch.nn.Module]): The modules to apply fully_shard to.
@@ -136,22 +109,18 @@ def apply_fully_shard(
         None
     """
     fully_shard_kwargs = {'mesh': fsdp2_config.device_mesh, 'reshard_after_forward': fsdp2_config.reshard_after_forward}
-    if fsdp2_config.mp_policy:
-        fully_shard_kwargs['mp_policy'] = fsdp2_config.mp_policy
-    if fsdp2_config.offload_policy:
-        fully_shard_kwargs['offload_policy'] = fsdp2_config.offload_policy

     # Apply fully_shard to each module in the list
     if len(independent_submodules) == 0:
         raise RuntimeError(
-            "Can't find any submodules to apply FSDP, e.g., the submodules may all have tied weights. Applying FSDP to the root model does not provide any memory savings.",
+            "Can't find any submodules to apply FSDP, e.g., the submodules may all have tied parameters. Applying FSDP to the root model does not provide any memory savings.",
         )

     independent_submodules, modules_tied = get_standalone_and_tied_modules(independent_submodules)
     if len(modules_tied) > 0:
         raise RuntimeError(
-            'Submodules to be sharded have tied weights. FSDP cannot be applied to modules with tied weights independently. '
-            'Please ensure that the submodules do not have tied weights.',
+            'Submodules to be sharded have tied parameters. FSDP cannot be applied to modules with tied parameters independently. '
+            'Please ensure that the submodules do not have tied parameters.',
         )

     # NOTE there is a bug fully_shard can not handle when the model has a child module which is the child of another
```
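
The tied-parameter checks above exist because two submodules that share a parameter object cannot each be wrapped by `fully_shard` independently. The following sketch is not the library implementation; it approximates the contract of `get_standalone_and_tied_modules` (split modules into standalone vs. tied) by comparing parameter identities, to show why weight tying trips the check.

```python
# Illustrative reimplementation, assuming the real helper has a similar contract.
import torch.nn as nn


def split_standalone_and_tied(modules: list[nn.Module]) -> tuple[list[nn.Module], set[nn.Module]]:
    seen: dict[int, nn.Module] = {}   # id(parameter) -> first module seen owning it
    tied: set[nn.Module] = set()
    for module in modules:
        for param in module.parameters(recurse=True):
            owner = seen.get(id(param))
            if owner is not None and owner is not module:
                # The same parameter object appears in two different modules: tied.
                tied.add(module)
                tied.add(owner)
            else:
                seen[id(param)] = module
    standalone = [m for m in modules if m not in tied]
    return standalone, tied


embed = nn.Embedding(100, 16)
head = nn.Linear(16, 100, bias=False)
head.weight = embed.weight  # weight tying, as in many LM output heads

standalone, tied = split_standalone_and_tied([embed, head, nn.Linear(16, 16)])
assert len(tied) == 2        # the tied pair is flagged; sharding them independently would fail
assert len(standalone) == 1  # only the untied Linear can be sharded on its own
```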

composer/trainer/trainer.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -1629,7 +1629,11 @@ def __init__(
                     f'Using closures and precision {self.state.precision} is not supported'
                     f' with FSDP. Please use another optimizer or precision type.',
                 )
-            self.state.scaler = ShardedGradScaler()
+            if isinstance(self.state.fsdp_config, FSDPConfig):
+                # Per TorchTitan doc and FSDP2 test: test_fsdp2_gradscaler.py,
+                # GradScaler can already handle state synchronization via torch._amp_foreach_non_finite_check_and_unscale_,
+                # so we don't need to use ShardedGradScaler
+                self.state.scaler = ShardedGradScaler()

         # suppressing FSDP warning when auto grad accum exits the forward pass before completing
         warnings.filterwarnings(action='ignore', message='Forward order differs from that of the first iteration')
@@ -1657,6 +1661,11 @@ def __init__(
         if self.state.fsdp_config is not None and self.state.fsdp_config.auto_wrap and not self.state.load_monolith_rank0_only:
             # Init with globally fixed seed so all HSDP replicas have the same initial weights
             with reproducibility.seed_context(self.state.rank_zero_seed):
+                # TODO (FSDP2): support calling FSDP2 wrapper depending on the config type
+                assert isinstance(
+                    self.state.fsdp_config,
+                    FSDPConfig,
+                ), f'prepare_fsdp_module requires FSDPConfig, got: {type(self.state.fsdp_config)}'
                 self.state.automicrobatch_fsdp_hook_handles, self.state.fsdp_modules = prepare_fsdp_module(
                     model,
                     optimizers,
@@ -1790,6 +1799,11 @@ def __init__(
                 not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_config.auto_wrap and
                 self.state.load_monolith_rank0_only
             ):
+                # TODO (FSDP2): support calling FSDP2 wrapper depending on the config type
+                assert isinstance(
+                    self.state.fsdp_config,
+                    FSDPConfig,
+                ), f'prepare_fsdp_module requires FSDPConfig, got: {type(self.state.fsdp_config)}'
                 # Init with globally fixed seed so all HSDP replicas have the same initial weights
                 with reproducibility.seed_context(self.state.rank_zero_seed):
                     self.state.automicrobatch_fsdp_hook_handles, self.state.fsdp_modules = prepare_fsdp_module(
```
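
The first hunk above narrows the gradient-scaler choice: FSDP1 still gets `ShardedGradScaler`, while for FSDP2 the stock `GradScaler` is considered sufficient because its non-finite check already synchronizes across ranks. Below is a minimal sketch of that selection logic; `pick_grad_scaler` is a hypothetical helper, not a Composer API.

```python
# Sketch of the scaler-selection rule encoded by the Trainer change above.
from typing import Optional, Union

from torch.cuda.amp.grad_scaler import GradScaler
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

from composer.utils import FSDP2Config, FSDPConfig


def pick_grad_scaler(fsdp_config: Optional[Union[FSDPConfig, FSDP2Config]]) -> GradScaler:
    if isinstance(fsdp_config, FSDPConfig):
        # FSDP1 shards gradients in a way the stock scaler cannot unscale correctly.
        return ShardedGradScaler()
    # FSDP2 (or no FSDP): the plain GradScaler is assumed to be enough.
    return GradScaler()


assert type(pick_grad_scaler(FSDPConfig())) is ShardedGradScaler
assert type(pick_grad_scaler(FSDP2Config())) is GradScaler
```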

composer/utils/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -75,7 +75,7 @@
     UCObjectStore,
     build_remote_backend,
 )
-from composer.utils.parallelism import FSDPConfig, ParallelismConfig, TPConfig
+from composer.utils.parallelism import FSDP2Config, FSDPConfig, ParallelismConfig, TPConfig
 from composer.utils.remote_uploader import RemoteFilesExistingCheckStatus, RemoteUploader
 from composer.utils.retrying import retry
 from composer.utils.string_enum import StringEnum
@@ -152,6 +152,7 @@
     'STR_TO_DTYPE',
     'ParallelismType',
     'FSDPConfig',
+    'FSDP2Config',
     'TPConfig',
     'ParallelismConfig',
     'MLFLOW_EXPERIMENT_ID_FORMAT_KEY',
```
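
With this re-export, user code can import FSDP2Config from the same place as the other parallelism configs:

```python
# FSDP2Config now sits next to the existing parallelism configs.
from composer.utils import FSDP2Config, FSDPConfig, ParallelismConfig, TPConfig

print(FSDP2Config.__module__)  # composer.utils.parallelism
```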

composer/utils/parallelism.py

Lines changed: 68 additions & 0 deletions
```diff
@@ -3,6 +3,7 @@


 """Parallelism configs."""

+import warnings
 from dataclasses import dataclass, field
 from typing import Any, Optional

@@ -61,6 +62,72 @@ def device_mesh(self, value: Optional[DeviceMesh]):
         self._device_mesh = value


+@dataclass
+class FSDP2Config:
+    """Configuration for Fully Sharded Data Parallelism (FSDP2).
+
+    Args:
+        device_mesh (Optional[DeviceMesh]): The DeviceMesh for sharding. If None, a default 1D mesh is created.
+            For 1D mesh, parameters are fully sharded across the mesh (FSDP).
+            For 2D mesh, parameters are sharded across the 1st dimension and replicated across the 0th dimension (HSDP).
+        reshard_after_forward (Union[bool, int]): Controls parameter behavior after forward.
+    """
+
+    # Settable core FSDP2 attrs
+    device_mesh: Optional[DeviceMesh] = None
+    reshard_after_forward: bool | int = True
+
+    ### Temporary read-only properties for FSDP 1 compatibility ###
+    # to be supported in FSDP2
+    @property
+    def auto_wrap(self) -> bool:
+        return False
+
+    @property
+    def load_monolith_rank0_only(self) -> bool:
+        return False
+
+    @property
+    def sync_module_states(self) -> bool:
+        return False
+
+    @property
+    def load_planner(self) -> Optional[Any]:
+        return None
+
+    @property
+    def save_planner(self) -> Optional[Any]:
+        return None
+
+    @property
+    def sharded_ckpt_prefix_dir(self) -> str:
+        return 'ep{epoch}-ba{batch}'
+
+    @property
+    def activation_cpu_offload(self) -> bool:
+        return False
+
+    @property
+    def data_parallel_shard_degree(self) -> int:
+        return -1
+
+    @property
+    def data_parallel_replicate_degree(self) -> Optional[int]:
+        return None
+
+    # to be deprecated in FSDP2
+    @property
+    def state_dict_type(self) -> str:
+        return 'sharded'
+
+    @property
+    def use_orig_params(self) -> bool:
+        return True
+
+    def __post_init__(self):
+        warnings.warn('FSDP2 Config/APIs are experimental and subject to heavy changes', UserWarning)
+
+
 @dataclass
 class TPConfig:
     """Configuration for tensor parallelism (TP)."""
@@ -74,3 +141,4 @@ class ParallelismConfig:
     """Configuration for parallelism."""
     fsdp: Optional[FSDPConfig] = None
     tp: Optional[TPConfig] = None
+    fsdp2: Optional[FSDP2Config] = None
```
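
A short sketch of how the new dataclass behaves as defined in the diff above: construction emits the experimental-API warning, the FSDP1-compatibility shims are read-only defaults rather than settable fields, and `ParallelismConfig` gains a dedicated `fsdp2` slot.

```python
# Exercises only behavior visible in the diff above.
import warnings

from composer.utils import FSDP2Config, ParallelismConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    cfg = FSDP2Config(reshard_after_forward=False)
assert any('experimental' in str(w.message) for w in caught)

# Read-only shims keep FSDP1-era attribute accesses in State/Trainer working.
assert cfg.auto_wrap is False
assert cfg.state_dict_type == 'sharded'
assert cfg.use_orig_params is True

# The new ParallelismConfig slot carries the FSDP2 config to Trainer/State.
parallelism = ParallelismConfig(fsdp2=cfg)
assert parallelism.fsdp is None and parallelism.fsdp2 is cfg
```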

tests/trainer/fsdp2_context.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -0,0 +1,27 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, Optional
+
+import pytest
+import torch
+from packaging import version
+
+SKIP_TEST = version.parse(torch.__version__) < version.parse('2.6.0')
+if not SKIP_TEST:
+    # TODO (FSDP2) move this to top once we deprecate torch 2.5
+    from composer.distributed import fsdp2
+    prepare_fully_shard = fsdp2.prepare_fully_shard
+    legalize_param_sharing_between_modules = fsdp2.legalize_param_sharing_between_modules
+    get_standalone_and_tied_modules = fsdp2.get_standalone_and_tied_modules
+else:
+    prepare_fully_shard = lambda *args, **kwargs: None
+    legalize_param_sharing_between_modules = lambda *args, **kwargs: None
+    get_standalone_and_tied_modules = lambda *args, **kwargs: ([], set())
+
+
+def fsdp2_context(func: Callable) -> Optional[Callable]:
+    """Decorator to run FSDP2 tests only on torch 2.6+ and suppress the experimental FSDP2 API warning."""
+    func = pytest.mark.skipif(SKIP_TEST, reason='Skipping test for torch version < 2.6.0')(func)
+    func = pytest.mark.filterwarnings('ignore:FSDP2 Config/APIs are experimental*:UserWarning')(func)
+    return func
```
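
A hypothetical test showing how this helper might be used from within the repo's test suite: the decorator skips the test on torch < 2.6 (where the imports above fall back to no-op stubs) and silences the experimental-API warning. The exact assertions assume `get_standalone_and_tied_modules` returns (standalone modules, tied modules) as its type hints indicate.

```python
# Hypothetical test body; imports assume it lives next to tests/trainer/fsdp2_context.py.
import torch.nn as nn

from tests.trainer.fsdp2_context import fsdp2_context, get_standalone_and_tied_modules


@fsdp2_context
def test_no_tied_modules():
    modules = [nn.Linear(4, 4), nn.Linear(4, 4)]  # no shared parameters
    standalone, tied = get_standalone_and_tied_modules(modules)
    assert len(standalone) == 2
    assert len(tied) == 0
```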
