
Commit 5b033b7

update optimizer code
1 parent 44345ae commit 5b033b7

2 files changed (+70, -3 lines)

composer/distributed/fsdp2.py

Lines changed: 55 additions & 0 deletions
@@ -8,6 +8,7 @@
 from typing import Optional, Union
 
 from torch import nn
+from torch.optim import Optimizer
 from torch.distributed._tensor.device_mesh import DeviceMesh
 from torch.distributed.fsdp._fully_shard import fully_shard
 from torch.distributed.fsdp._fully_shard._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
@@ -119,6 +120,50 @@ def _check_param_sharing(module: nn.Module):
     # Start the check from the root model
     _check_param_sharing(model)
 
+def update_optimizer_modules(
+    optimizer: Optimizer,
+    modules_to_shard: list[nn.Module],
+    model: nn.Module,
+    orig_param_id_to_name: dict[int, str],
+) -> None:
+    """Updates the optimizer's parameter groups to use the sharded model parameters.
+    Assumes no training has occurred yet and the optimizer state is empty.
+
+    Args:
+        optimizer (Optimizer): The optimizer to update.
+        modules_to_shard (list[nn.Module]): The modules that will be sharded.
+        model (nn.Module): The parent model that is also sharded.
+        orig_param_id_to_name (dict[int, str]): Mapping from original parameter IDs to their names.
+    """
+    # Build a mapping from parameter name to sharded parameter (after sharding)
+    name_to_sharded_param = dict(model.named_parameters())
+    for module in modules_to_shard:
+        name_to_sharded_param.update(dict(module.named_parameters()))
+
+    # Create a mapping from old parameters to new DTensor parameters
+    old_to_new_param = {}
+    for group in optimizer.param_groups:
+        for param in group['params']:
+            param_name = orig_param_id_to_name.get(id(param))
+            if param_name is not None and param_name in name_to_sharded_param:
+                old_to_new_param[param] = name_to_sharded_param[param_name]
+            else:
+                # TODO: Look into whether we will ever hit this case...
+                raise ValueError(f"Parameter {param} not found in model")
+
+    # Update param groups with new parameters
+    new_param_groups = []
+    for group in optimizer.param_groups:
+        new_group = {k: v for k, v in group.items() if k != 'params'}
+        new_params = [old_to_new_param[param] for param in group['params']]
+        new_group['params'] = new_params
+        new_param_groups.append(new_group)
+
+    # Update param groups
+    optimizer.param_groups.clear()
+    for group in new_param_groups:
+        optimizer.add_param_group(group)
+
 
 def apply_fully_shard(
     model: nn.Module,
@@ -178,6 +223,7 @@ def apply_fully_shard(
 
 def prepare_fully_shard(
     model: nn.Module,
+    optimizer: Optional[Optimizer],
     fsdp2_config: FSDP2Config,
 ) -> None:
     """Applies FSDP2's `fully_shard` to the model according to given fsdp2_config.
@@ -190,4 +236,13 @@ def prepare_fully_shard(
         None
     """
     modules_to_shard, _ = get_standalone_and_tied_modules(list(model.children()))
+
+    # Build the parameter ID to name mapping (with no duplicates)
+    orig_param_id_to_name = {id(param): name for name, param in model.named_parameters()}
+    for module in modules_to_shard:
+        orig_param_id_to_name.update({id(param): name for name, param in module.named_parameters()})
+
     apply_fully_shard(model, modules_to_shard, fsdp2_config)
+
+    # After the model is sharded in place, we can update the optimizer state to use the DTensor parameters
+    update_optimizer_modules(optimizer, modules_to_shard, model, orig_param_id_to_name)
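
For context, a minimal self-contained sketch of the remapping pattern that update_optimizer_modules follows: record id(param) -> name before the parameters are swapped out, then rebuild the optimizer's param groups from the post-swap named_parameters(). This sketch uses a toy nn.Linear with SGD and plain tensors standing in for DTensors (no distributed setup); the names here are illustrative and not part of the Composer API.

import torch
from torch import nn

# Toy module; re-wrapping its parameters below stands in for fully_shard
# replacing nn.Parameters with DTensor parameters in place.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 1. Record parameter identity -> name before the swap (as prepare_fully_shard now does).
orig_param_id_to_name = {id(p): name for name, p in model.named_parameters()}

# 2. Simulate sharding: replace each parameter with a new tensor object.
model.weight = nn.Parameter(model.weight.detach().clone())
model.bias = nn.Parameter(model.bias.detach().clone())

# 3. Rebuild the optimizer's param groups against the new parameters by name
#    (mirrors update_optimizer_modules).
name_to_new_param = dict(model.named_parameters())
new_param_groups = []
for group in optimizer.param_groups:
    new_group = {k: v for k, v in group.items() if k != 'params'}
    new_group['params'] = [name_to_new_param[orig_param_id_to_name[id(p)]] for p in group['params']]
    new_param_groups.append(new_group)
optimizer.param_groups.clear()
for group in new_param_groups:
    optimizer.add_param_group(group)

# The optimizer now references the live (post-swap) parameters again.
live_ids = {id(p) for p in model.parameters()}
assert all(id(p) in live_ids for p in optimizer.param_groups[0]['params'])

Because the docstring assumes no training has happened and the optimizer state is empty, only param_groups needs to be rebuilt; per-parameter state tensors would otherwise also have to be re-keyed.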

tests/trainer/test_fsdp2.py

Lines changed: 15 additions & 3 deletions
@@ -5,7 +5,6 @@
 import torch
 from packaging import version
 from torch.utils.data import DataLoader
-
 from composer.models import ComposerClassifier
 from composer.trainer.trainer import Trainer
 from composer.utils import dist
@@ -28,6 +27,7 @@
 
 @pytest.mark.parametrize('model', [SimpleWeightTiedModel, PartialWeightTiedModel])
 @pytest.mark.parametrize('device', _INIT_DEVICES)
+@pytest.mark.parametrize('optimizer', [torch.optim.Adam, torch.optim.SGD])
 @world_size(2)
 @pytest.mark.gpu
 @pytest.mark.filterwarnings('ignore:FSDP2 Config/APIs are experimental*:UserWarning')
@@ -36,6 +36,7 @@ def test_fsdp2_initialization_with_tied_params(
     model: ComposerClassifier,
     world_size: int,
     device: str,
+    optimizer: type[torch.optim.Optimizer],
 ):
     """test FSDP2 initialization for a simple model with weight tying and a model where two modules
     from separate submodules have weight tying applied.
@@ -52,15 +53,27 @@ def test_fsdp2_initialization_with_tied_params(
         mp_policy=None,
         offload_policy=None,
     )
-    prepare_fully_shard(model=model.module, fsdp2_config=fsdp2_config)
+    optimizer = optimizer(model.parameters(), lr=0.1)
+    prepare_fully_shard(model=model.module, optimizer=optimizer, fsdp2_config=fsdp2_config)
 
     # Initialization checks
     assert len(model.mlp._forward_pre_hooks) == 1, 'Expected 1 forward pre-hook on the mlp module'
     assert len(model.mlp.fc1._forward_pre_hooks) == 0, 'Expected 0 forward pre-hook on the fc1 module'
     assert len(model.mlp.fc2._forward_pre_hooks) == 0, 'Expected 0 forward pre-hook on the fc2 module'
     assert len(model.module._forward_pre_hooks) == 1, 'Expected 1 forward pre-hook on the root module'
+
+    # Check that the weights are DTensor
     assert isinstance(model.mlp.fc1.weight, DTensor), 'mlp.fc1.weight should be a DTensor'
     assert isinstance(model.mlp.fc2.weight, DTensor), 'mlp.fc2.weight should be a DTensor'
+    # Check that all optimizer parameters are DTensor (we only have one param group)
+    assert len(optimizer.param_groups) == 1, 'Expected 1 param group in optimizer'
+    assert len(optimizer.param_groups[0]['params']) >= 1, 'Expected at least 1 parameter in optimizer (depends on the model)'
+    assert all(isinstance(param, DTensor) for param in optimizer.param_groups[0]['params']), 'All parameters in optimizer should be DTensor'
+    # Validate that the ids of the parameters in the optimizer exist in the model
+    model_param_ids = [id(p) for p in model.parameters()]
+    for param in optimizer.param_groups[0]['params']:
+        assert id(param) in model_param_ids, 'Parameter id in optimizer does not match parameter id in model'
+
     if isinstance(model, PartialWeightTiedModel):
         assert len(model.fc3._forward_pre_hooks) == 1, 'Expected 1 forward pre-hook on the fc3 module'
         assert model.mlp.fc1.weight.size(0) == model.mlp.fc2.weight.to_local(
@@ -72,7 +85,6 @@ def test_fsdp2_initialization_with_tied_params(
     for module in model.modules():
         model.param_init_fn(module)
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
     trainer = Trainer(
         model=model,
         optimizers=optimizer,