Commit 29e7593

Update optimizer params for fsdp2 (#3822)
1 parent 635d9a7 commit 29e7593

File tree

  composer/distributed/fsdp2.py
  tests/trainer/test_fsdp2.py
  tests/trainer/test_fsdp2_gradscaler.py

3 files changed, +188 -4 lines changed

composer/distributed/fsdp2.py

Lines changed: 85 additions & 0 deletions
@@ -3,6 +3,10 @@
 
 """Helpers for FSDP2."""
 
+import warnings
+from typing import Optional
+
+import torch
 import torch.nn as nn
 from torch.distributed.fsdp._fully_shard import fully_shard
 
@@ -91,6 +95,77 @@ def _check_param_sharing(module: nn.Module):
     _check_param_sharing(model)
 
 
+def update_optimizer_modules(
+    optimizer: torch.optim.Optimizer,
+    model: nn.Module,
+    orig_param_to_name: dict[torch.nn.Parameter, str],
+) -> None:
+    """Updates the optimizer's parameter groups to use the sharded model parameters.
+
+    Assumes no training has occurred yet and the optimizer state is empty. If the optimizer state is not empty,
+    it will be cleared with a warning.
+
+    Args:
+        optimizer (Optimizer): The optimizer to update.
+        model (nn.Module): The parent model that is also sharded.
+        orig_param_to_name (dict[torch.nn.Parameter, str]): Mapping from original parameters to their names.
+    """
+    # Check if the optimizer state is empty.
+    # If not, clear it and warn the user.
+    if optimizer.state:
+        warnings.warn(
+            'FSDP2 wrapping assumes the optimizer state is empty (i.e., training has not started), '
+            'but non-empty optimizer state was found. Optimizer state will be cleared.',
+        )
+        optimizer.state.clear()
+
+    # Build a mapping from parameter name to sharded parameter (after sharding)
+    name_to_sharded_param = dict(model.named_parameters(recurse=True))
+
+    # Create a mapping from old parameters to new DTensor parameters
+    # Note: if params are tied and the same parameter is in multiple groups, pytorch will raise an error
+    old_to_new_param = {}
+    unseen_params = set()
+    for group in optimizer.param_groups:
+        for param in group['params']:
+            # Note: the names of the parameters stay the same after sharding, so we can do the following.
+            param_name = orig_param_to_name.get(param, None)
+            if param_name is None:
+                # This means that the parameter is not in the original model.
+                # As `prepare_fully_shard` takes in the optimizer itself, we don't have a way to
+                # identify the parameter name, so we just use the id.
+                unseen_params.add(f'optimizer.param_id.{id(param)}')
+            elif param_name not in name_to_sharded_param:
+                # This means that the base model parameter is not in the sharded model.
+                # This should never happen; we note this in the error message.
+                unseen_params.add(f'model.param_name.{param_name}')
+            else:
+                old_to_new_param[param] = name_to_sharded_param[param_name]
+
+    # Raise an error with all the parameters that were not found in the sharded model
+    if len(unseen_params) > 0:
+        raise ValueError(
+            f'The same model must be passed to the optimizer and trainer but the '
+            f'following parameters were not found in the sharded model: {list(unseen_params)}. '
+            'All parameters prefixed with "optimizer.param_id" imply that the optimizer has the wrong model. '
+            'All parameters prefixed with "model.param_name" imply a significant issue where sharding '
+            'has not been applied correctly.',
+        )
+
+    # Update param groups with new parameters
+    new_param_groups = []
+    for group in optimizer.param_groups:
+        new_group = {k: v for k, v in group.items() if k != 'params'}
+        new_params = [old_to_new_param[param] for param in group['params']]
+        new_group['params'] = new_params
+        new_param_groups.append(new_group)
+
+    # Update param groups
+    optimizer.param_groups.clear()
+    for group in new_param_groups:
+        optimizer.add_param_group(group)
+
+
 def apply_fully_shard(
     model: nn.Module,
     independent_submodules: list[nn.Module],
@@ -147,6 +222,7 @@ def apply_fully_shard(
 
 def prepare_fully_shard(
     model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer],
     fsdp2_config: FSDP2Config,
 ) -> None:
     """Applies FSDP2's `fully_shard` to the model according to given fsdp2_config.
@@ -158,5 +234,14 @@ def prepare_fully_shard(
     Returns:
         None
     """
+    # Build the parameter to name mapping
+    orig_param_to_name = {p: n for n, p in model.named_parameters(recurse=True)}
+
+    # Get the modules to shard
     modules_to_shard, _ = get_standalone_and_tied_modules(list(model.children()))
+
     apply_fully_shard(model, modules_to_shard, fsdp2_config)
+
+    # If the optimizer is provided, update the optimizer's parameter groups to use the sharded model's DTensor parameters
+    if optimizer is not None:
+        update_optimizer_modules(optimizer, model, orig_param_to_name)
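
With this change, the optimizer is built against the unsharded model and handed to prepare_fully_shard, which shards the model and then remaps the optimizer's param groups onto the sharded DTensor parameters via update_optimizer_modules. A minimal usage sketch (not code from this commit; it assumes an initialized distributed process group and that prepare_fully_shard and FSDP2Config are imported from Composer as in the tests below, since their import paths are not shown in this diff):

    import torch
    import torch.nn as nn

    # Toy model for illustration; any nn.Module with child submodules works the same way.
    model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10)).to('cuda')

    # The optimizer is created from the unsharded parameters...
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # ...and prepare_fully_shard rebuilds its param groups around the sharded parameters.
    prepare_fully_shard(model, optimizer, FSDP2Config())

    # Every parameter the optimizer now holds is one of the model's (sharded) parameters.
    sharded_param_ids = {id(p) for p in model.parameters()}
    assert all(
        id(p) in sharded_param_ids for group in optimizer.param_groups for p in group['params']
    )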

tests/trainer/test_fsdp2.py

Lines changed: 101 additions & 2 deletions
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pathlib
+from typing import Optional
 
 import pytest
 import torch
@@ -69,6 +70,7 @@ def create_trainer_with_model(
     num_classes: int = 10,
     max_duration: str = '10ep',
     use_fsdp2: bool = True,
+    optimizer: Optional[torch.optim.Optimizer] = None,
 ) -> Trainer:
     """Helper function to create a Trainer with a model, dataloader, and FSDP2 configuration."""
     dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
@@ -79,7 +81,7 @@ def create_trainer_with_model(
         # Trainer is not calling prepare_fully_shard yet, so we need to do it manually
         fsdp2_config = FSDP2Config()
         # NOTE we can only apply FSDP2 to ComposerClassifier's module field until we support auto_wrap
-        prepare_fully_shard(model=model.module, fsdp2_config=fsdp2_config)
+        prepare_fully_shard(model=model.module, fsdp2_config=fsdp2_config, optimizer=optimizer)
         # NOTE module to_empty should only happen after the model is fully sharded and parameters are converted to DTensor,
         # otherwise to_empty breaks weight tying
         # TODO (FSDP2) we should guardrail this in prepare_fully_shard
@@ -91,7 +93,8 @@ def create_trainer_with_model(
         parallelism_config.fsdp2 = fsdp2_config
     else:
         parallelism_config.fsdp = FSDPConfig(state_dict_type='sharded')
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    if optimizer is None:
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
     trainer = Trainer(
         model=model,
         optimizers=optimizer,
@@ -233,3 +236,99 @@ def test_fsdp2_load_from_fsdp1(
             fsdp1_param,
             param.full_tensor(),
         ), f'Weights: {name} should be equal after loading, however one is {fsdp1_param} and the other is {param.full_tensor()}'
+
+
+@world_size(2)
+@pytest.mark.gpu
+@fsdp2_context
+@pytest.mark.parametrize('case', ['all_params_one_group', 'subset_one_group', 'multiple_groups'])
+@pytest.mark.parametrize('device', _INIT_DEVICES)
+def test_fsdp2_optimizer_handling(
+    world_size: int,
+    case: str,
+    device: str,
+):
+    """Test FSDP2 correctly updates optimizer state for various configurations."""
+    del world_size
+
+    NUM_FEATURES = 10
+    NUM_CLASSES = 10
+    model = PartialWeightTiedModel(num_features=NUM_FEATURES, device=device)
+
+    all_params_list = list(model.parameters())
+    fc1_params_list = list(model.mlp.fc1.parameters())
+    fc3_params_list = list(model.fc3.parameters())
+
+    if case == 'all_params_one_group':
+        optimizer_input = [{'params': all_params_list, 'lr': 0.01}]
+    elif case == 'subset_one_group':
+        optimizer_input = [{'params': fc1_params_list, 'lr': 0.02}]  # Same as fc2_params_list (since tied weights)
+    elif case == 'multiple_groups':
+        optimizer_input = [
+            {
+                'params': fc1_params_list,
+                'lr': 0.01,
+            },  # Same as fc2_params_list (since tied weights)
+            {
+                'params': fc3_params_list,
+                'lr': 0.02,
+            },
+        ]
+    else:
+        raise ValueError(f'Invalid case: {case}')
+
+    optimizer = torch.optim.Adam(optimizer_input)
+    trainer = create_trainer_with_model(model=model, num_classes=NUM_CLASSES, use_fsdp2=True, optimizer=optimizer)
+
+    def validate_optimizer_state(current_optimizer: torch.optim.Optimizer, stage: str):
+        assert len(current_optimizer.param_groups) == len(optimizer_input), \
+            f'[{case}/{stage}] Group count mismatch. Expected {len(optimizer_input)}, Got {len(current_optimizer.param_groups)}'
+        for i, group in enumerate(current_optimizer.param_groups):
+            opt_params = group['params']
+            # Check that the number of parameters in the optimizer group matches the number of parameters in the input
+            assert len(opt_params) == len(optimizer_input[i]['params']), \
+                f"[{case}/{stage}] Group {i}: Param count mismatch. Expected {len(optimizer_input[i]['params'])}, Got {len(opt_params)}"
+
+            # Check that all parameters are DTensor
+            assert all(isinstance(p, DTensor) for p in opt_params), \
+                f'[{case}/{stage}] Group {i}: Not all parameters are DTensors'
+
+            # Check that all keys match between input and current groups
+            input_keys = set(optimizer_input[i].keys())
+            group_keys = set(group.keys())
+            assert input_keys == group_keys, \
+                f'[{case}/{stage}] Group {i}: Key mismatch. Expected {input_keys}, Got {group_keys}'
+
+            # Check values for all keys
+            for key in input_keys:
+                if key != 'params':
+                    assert group[key] == optimizer_input[i][key], \
+                        f'[{case}/{stage}] Group {i}: {key} mismatch. Expected {optimizer_input[i][key]}, Got {group[key]}'
+
+    # Validate optimizer state after sharding and before training
+    validate_optimizer_state(optimizer, stage='after_fully_shard')
+
+    trainer.fit()
+
+    # Validate optimizer state after training
+    validate_optimizer_state(optimizer, stage='after_fit')
+
+
+@world_size(2)
+@pytest.mark.gpu
+@fsdp2_context
+def test_fsdp2_optimizer_raises_error_when_optimizer_modules_dont_match(world_size: int,):
+    """Test FSDP2 raises an error when the optimizer modules don't match the model modules."""
+    del world_size
+
+    NUM_FEATURES = 10
+    NUM_CLASSES = 10
+    model = SimpleComposerMLP(num_features=NUM_FEATURES, device='cuda', num_classes=NUM_CLASSES)
+    other_model = SimpleWeightTiedModel(num_features=NUM_FEATURES, device='cuda')
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+    with pytest.raises(ValueError) as e:
+        create_trainer_with_model(model=other_model, num_classes=NUM_CLASSES, use_fsdp2=True, optimizer=optimizer)
+    # Check that the error message uses the correct prefix implying an optimizer mismatch.
+    # We check with `optimizer.param_id.` (with the period) since `optimizer.param_id` exists
+    # by default in the error message's legend.
+    assert 'optimizer.param_id.' in str(e.value)
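
The mismatch test above drives the error through create_trainer_with_model; the same failure can be sketched directly against prepare_fully_shard (a hypothetical snippet, again assuming an initialized process group and the Composer imports used earlier in this file):

    import torch
    import torch.nn as nn

    # Optimizer built from one model, sharding applied to another.
    sharded_model = nn.Sequential(nn.Linear(8, 8)).to('cuda')
    other_model = nn.Sequential(nn.Linear(8, 8)).to('cuda')
    optimizer = torch.optim.Adam(other_model.parameters(), lr=0.01)

    try:
        prepare_fully_shard(sharded_model, optimizer, FSDP2Config())
    except ValueError as err:
        # Parameters unknown to the sharded model are reported under the
        # 'optimizer.param_id.<id>' prefix, which is what the test asserts on.
        assert 'optimizer.param_id.' in str(err)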

tests/trainer/test_fsdp2_gradscaler.py

Lines changed: 2 additions & 2 deletions
@@ -43,9 +43,9 @@ def test_fsdp2_with_gradscaler_inf(world_size: int):
     dtype = torch.float16
 
     model = SimpleModel().to('cuda')
-    # Apply fully_shard to the model
-    prepare_fully_shard(model, FSDP2Config())
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+    # Apply fully_shard to the model
+    prepare_fully_shard(model, optimizer, FSDP2Config())
 
     # dummy inputs and targets
     inputs = torch.randn(1, 2, device='cuda', dtype=dtype)
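
Note the reordering: the optimizer is now constructed before prepare_fully_shard, since the call remaps its param groups onto the sharded parameters. A quick sanity check of that remapping after the call (illustrative snippet, not part of this commit; it assumes DTensor is importable from torch.distributed.tensor, which is available in recent PyTorch releases):

    from torch.distributed.tensor import DTensor

    # After prepare_fully_shard(model, optimizer, FSDP2Config()), every parameter
    # the optimizer holds should be a sharded DTensor taken from the model.
    assert all(
        isinstance(p, DTensor) for group in optimizer.param_groups for p in group['params']
    )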
