Commit de43ee2

updated code
1 parent 5b033b7 commit de43ee2

File tree: 2 files changed (+100, -22 lines)

composer/distributed/fsdp2.py
tests/trainer/test_fsdp2.py

composer/distributed/fsdp2.py

Lines changed: 18 additions & 8 deletions
@@ -5,14 +5,15 @@
 
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Optional, Union, Sequence
 
+import torch
 from torch import nn
-from torch.optim import Optimizer
 from torch.distributed._tensor.device_mesh import DeviceMesh
 from torch.distributed.fsdp._fully_shard import fully_shard
 from torch.distributed.fsdp._fully_shard._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
 
+from composer.utils import ensure_tuple
 
 @dataclass
 class FSDP2Config:
@@ -121,7 +122,7 @@ def _check_param_sharing(module: nn.Module):
     _check_param_sharing(model)
 
 def update_optimizer_modules(
-    optimizer: Optimizer,
+    optimizers: Optional[Union[torch.optim.Optimizer, Sequence[torch.optim.Optimizer]]],
     modules_to_shard: list[nn.Module],
     model: nn.Module,
     orig_param_id_to_name: dict[int, str],
@@ -135,21 +136,29 @@ def update_optimizer_modules(
         model (nn.Module): The parent model that is also sharded.
         orig_param_id_to_name (dict[int, str]): Mapping from original parameter IDs to their names.
     """
+    # Using the same logic as in FSDP1 to address multiple optimizers
+    optimizers_tuple = ensure_tuple(optimizers)
+    if len(optimizers_tuple) != 1:
+        raise NotImplementedError(f'Only one optimizer is supported; found {len(optimizers_tuple)} optimizers')
+
+    optimizer = optimizers_tuple[0]
+
     # Build a mapping from parameter name to sharded parameter (after sharding)
     name_to_sharded_param = dict(model.named_parameters())
     for module in modules_to_shard:
         name_to_sharded_param.update(dict(module.named_parameters()))
 
     # Create a mapping from old parameters to new DTensor parameters
+    # Note: if params are tied and the same parameter is in multiple groups, pytorch will raise an error
     old_to_new_param = {}
     for group in optimizer.param_groups:
         for param in group['params']:
             param_name = orig_param_id_to_name.get(id(param))
+            # Note: the names of the parameters stay the same after sharding so we can do the following.
             if param_name is not None and param_name in name_to_sharded_param:
                 old_to_new_param[param] = name_to_sharded_param[param_name]
             else:
-                # TODO: Look into whether we will ever hit this case...
-                raise ValueError(f"Parameter {param} not found in model")
+                raise ValueError(f"The same model must be passed to the optimizer and trainer.")
 
     # Update param groups with new parameters
     new_param_groups = []
@@ -223,7 +232,7 @@ def apply_fully_shard(
 
 def prepare_fully_shard(
     model: nn.Module,
-    optimizer: Optional[Optimizer],
+    optimizers: Optional[Union[torch.optim.Optimizer, Sequence[torch.optim.Optimizer]]],
    fsdp2_config: FSDP2Config,
 ) -> None:
     """Applies FSDP2's `fully_shard` to the model according to given fsdp2_config.
@@ -244,5 +253,6 @@ def prepare_fully_shard(
 
     apply_fully_shard(model, modules_to_shard, fsdp2_config)
 
-    # After the model is sharded in place, we can update the optimizer state to use the DTensor parameters
-    update_optimizer_modules(optimizer, modules_to_shard, model, orig_param_id_to_name)
+    # If the optimizer is provided, update the optimizer state to use the DTensor parameters
+    if optimizers is not None:
+        update_optimizer_modules(optimizers, modules_to_shard, model, orig_param_id_to_name)
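For context, the core of update_optimizer_modules is a name-based remapping of the optimizer's param groups onto the model's new (sharded) parameter objects. Below is a minimal standalone sketch of that idea, not part of this commit: a plain parameter clone stands in for the actual DTensor sharding so it runs without a distributed setup, and the variable names are illustrative.

import torch
from torch import nn

model = nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Record which name each original parameter object had (mirrors orig_param_id_to_name).
orig_param_id_to_name = {id(p): name for name, p in model.named_parameters()}

# Stand-in for sharding: replace every parameter with a new Parameter object.
for name, p in list(model.named_parameters()):
    setattr(model, name, nn.Parameter(p.detach().clone()))

# Re-point each param group at the new parameters by matching on the recorded names
# (mirrors name_to_sharded_param / old_to_new_param in the commit).
name_to_new_param = dict(model.named_parameters())
for group in optimizer.param_groups:
    group['params'] = [name_to_new_param[orig_param_id_to_name[id(p)]] for p in group['params']]

# The optimizer now holds the model's current parameter objects.
current_ids = {id(p) for p in model.parameters()}
assert all(id(p) in current_ids for g in optimizer.param_groups for p in g['params'])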

tests/trainer/test_fsdp2.py

Lines changed: 82 additions & 14 deletions
@@ -5,6 +5,7 @@
 import torch
 from packaging import version
 from torch.utils.data import DataLoader
+
 from composer.models import ComposerClassifier
 from composer.trainer.trainer import Trainer
 from composer.utils import dist
@@ -27,7 +28,6 @@
 
 @pytest.mark.parametrize('model', [SimpleWeightTiedModel, PartialWeightTiedModel])
 @pytest.mark.parametrize('device', _INIT_DEVICES)
-@pytest.mark.parametrize('optimizer', [torch.optim.Adam, torch.optim.SGD])
 @world_size(2)
 @pytest.mark.gpu
 @pytest.mark.filterwarnings('ignore:FSDP2 Config/APIs are experimental*:UserWarning')
@@ -36,7 +36,6 @@ def test_fsdp2_initialization_with_tied_params(
     model: ComposerClassifier,
     world_size: int,
     device: str,
-    optimizer: type[torch.optim.Optimizer],
 ):
     """test FSDP2 initialization for a simple model with weight tying and a model where two modules
     from separate submodules have weight tying applied.
@@ -53,26 +52,16 @@ def test_fsdp2_initialization_with_tied_params(
         mp_policy=None,
         offload_policy=None,
     )
-    optimizer = optimizer(model.parameters(), lr=0.1)
-    prepare_fully_shard(model=model.module, optimizer=optimizer, fsdp2_config=fsdp2_config)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    prepare_fully_shard(model=model.module, optimizers=optimizer, fsdp2_config=fsdp2_config)
 
     # Initialization checks
     assert len(model.mlp._forward_pre_hooks) == 1, 'Expected 1 forward pre-hook on the mlp module'
     assert len(model.mlp.fc1._forward_pre_hooks) == 0, 'Expected 0 forward pre-hook on the fc1 module'
     assert len(model.mlp.fc2._forward_pre_hooks) == 0, 'Expected 0 forward pre-hook on the fc2 module'
     assert len(model.module._forward_pre_hooks) == 1, 'Expected 1 forward pre-hook on the root module'
-
-    # Check that the weights are DTensor
     assert isinstance(model.mlp.fc1.weight, DTensor), 'mlp.fc1.weight should be a DTensor'
     assert isinstance(model.mlp.fc2.weight, DTensor), 'mlp.fc2.weight should be a DTensor'
-    # Check that all optimizer parameters are DTensor (we only have one param group)
-    assert len(optimizer.param_groups) == 1, 'Expected 1 param group in optimizer'
-    assert len(optimizer.param_groups[0]['params']) >= 1, 'Expected at least 1 parameter in optimizer (depends on the model)'
-    assert all(isinstance(param, DTensor) for param in optimizer.param_groups[0]['params']), 'All parameters in optimizer should be DTensor'
-    # Validate that the ids of the parameters in the optimizer exist in the model
-    model_param_ids = [id(p) for p in model.parameters()]
-    for param in optimizer.param_groups[0]['params']:
-        assert id(param) in model_param_ids, 'Parameter id in optimizer does not match parameter id in model'
 
     if isinstance(model, PartialWeightTiedModel):
         assert len(model.fc3._forward_pre_hooks) == 1, 'Expected 1 forward pre-hook on the fc3 module'
@@ -98,3 +87,82 @@ def test_fsdp2_initialization_with_tied_params(
     weight_2 = model.mlp.fc2.weight.full_tensor()
     assert (model.mlp.fc1.weight is model.mlp.fc2.weight)
     assert (torch.equal(weight_1, weight_2))
+
+
+@pytest.mark.parametrize('case', ['all_params_one_group', 'subset_one_group', 'multiple_groups'])
+@pytest.mark.parametrize('device', _INIT_DEVICES)
+@world_size(2)
+@pytest.mark.gpu
+@pytest.mark.filterwarnings('ignore:FSDP2 Config/APIs are experimental*:UserWarning')
+@pytest.mark.skipif(SKIP_TEST, reason='FSDP2 is not available in torch < 2.6.0')
+def test_fsdp2_optimizer_handling(
+    case: str,
+    world_size: int,
+    device: str,
+):
+    """Test FSDP2 correctly updates optimizer state for various configurations."""
+    del world_size
+
+    num_classes = 10
+    model = PartialWeightTiedModel(num_features=num_classes, device=device)
+    dataset = RandomClassificationDataset(shape=(num_classes,), size=10, num_classes=num_classes)
+    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
+
+    all_params_list = list(model.parameters())
+    fc1_params_list = list(model.mlp.fc1.parameters())
+    fc3_params_list = list(model.fc3.parameters())
+    remaining_params_list = [
+        p for p in all_params_list
+        if not any(p is param for param in fc1_params_list)
+        and not any(p is param for param in fc3_params_list)
+        and not any(p is param for param in model.mlp.fc2.parameters())  # To avoid double counting the tied parameter
+    ]
+
+    if case == 'all_params_one_group':
+        optimizer_input = [{'params': all_params_list, 'lr': 0.01}]
+    elif case == 'subset_one_group':
+        optimizer_input = [{'params': fc1_params_list, 'lr': 0.02}]
+    elif case == 'multiple_groups':
+        optimizer_input = [
+            {'params': fc1_params_list, 'lr': 0.01},
+            {'params': fc3_params_list, 'lr': 0.02},
+            {'params': remaining_params_list, 'lr': 0.03},
+        ]
+
+    optimizer = torch.optim.Adam(optimizer_input)
+
+    fsdp2_config = FSDP2Config(
+        device_mesh=None,
+        reshard_after_forward=True,
+        mp_policy=None,
+        offload_policy=None,
+    )
+    prepare_fully_shard(model=model.module, optimizers=optimizer, fsdp2_config=fsdp2_config)
+
+    def validate_optimizer_state(current_optimizer: torch.optim.Optimizer, stage: str):
+        assert len(current_optimizer.param_groups) == len(optimizer_input), \
+            f"[{case}/{stage}] Group count mismatch. Expected {len(optimizer_input)}, Got {len(current_optimizer.param_groups)}"
+        for i, group in enumerate(current_optimizer.param_groups):
+            opt_params = group['params']
+            assert len(opt_params) == len(optimizer_input[i]['params']), \
+                f"[{case}/{stage}] Group {i}: Param count mismatch. Expected {len(optimizer_input[i]['params'])}, Got {len(opt_params)}"
+            assert all(isinstance(p, DTensor) for p in opt_params), \
+                f"[{case}/{stage}] Group {i}: Not all parameters are DTensors"
+            assert group['lr'] == optimizer_input[i]['lr'], \
+                f"[{case}/{stage}] Group {i}: LR mismatch. Expected {optimizer_input[i]['lr']}, Got {group['lr']}"
+
+    validate_optimizer_state(optimizer, stage="after_fully_shard")
+
+    model.to_empty(device='cuda')
+    for module in model.modules():
+        model.param_init_fn(module)
+
+    trainer = Trainer(
+        model=model,
+        optimizers=optimizer,
+        train_dataloader=dataloader,
+        max_duration='1ba',
+    )
+    trainer.fit()
+
+    validate_optimizer_state(optimizer, stage="after_fit")
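One behavior the added test does not exercise is the single-optimizer guard introduced in update_optimizer_modules. A hypothetical follow-up test, reusing the fixtures above (illustrative only, not part of this commit), might look like:

@pytest.mark.parametrize('device', _INIT_DEVICES)
@world_size(2)
@pytest.mark.gpu
@pytest.mark.filterwarnings('ignore:FSDP2 Config/APIs are experimental*:UserWarning')
@pytest.mark.skipif(SKIP_TEST, reason='FSDP2 is not available in torch < 2.6.0')
def test_fsdp2_rejects_multiple_optimizers(world_size: int, device: str):
    """Passing more than one optimizer should raise NotImplementedError."""
    del world_size

    model = PartialWeightTiedModel(num_features=10, device=device)
    optimizers = [
        torch.optim.Adam(model.mlp.fc1.parameters(), lr=0.01),
        torch.optim.SGD(model.fc3.parameters(), lr=0.01),
    ]
    fsdp2_config = FSDP2Config(
        device_mesh=None,
        reshard_after_forward=True,
        mp_policy=None,
        offload_policy=None,
    )
    # update_optimizer_modules normalizes the input with ensure_tuple and only
    # supports exactly one optimizer, so a list of two should be rejected.
    with pytest.raises(NotImplementedError):
        prepare_fully_shard(model=model.module, optimizers=optimizers, fsdp2_config=fsdp2_config)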
