Merged
Changes from 4 commits
86 changes: 86 additions & 0 deletions composer/distributed/fsdp2.py
@@ -3,6 +3,10 @@

"""Helpers for FSDP2."""

import warnings
from typing import Optional

import torch
import torch.nn as nn
from torch.distributed.fsdp._fully_shard import fully_shard

@@ -91,6 +95,78 @@ def _check_param_sharing(module: nn.Module):
_check_param_sharing(model)


def update_optimizer_modules(
optimizer: torch.optim.Optimizer,
model: nn.Module,
orig_param_to_name: dict[torch.nn.Parameter, str],
) -> None:
"""Updates the optimizer's parameter groups to use the sharded model parameters.

Assumes no training has occurred yet and the optimizer state is empty. If the optimizer state is not empty,
it will be cleared with a warning.

Args:
optimizer (Optimizer): The optimizer to update.
model (nn.Module): The sharded model whose parameters (now DTensors) replace the originals.
orig_param_to_name (dict[torch.nn.Parameter, str]): Mapping from original parameters to their names.
"""
# Check if the optimizer state is empty
# If not, clear it and warn the user
if optimizer.state:
warnings.warn(
'FSDP2 wrapping assumes the optimizer state is empty (i.e., training has not started), '
'but non-empty optimizer state was found. The optimizer state will be cleared.',
)
optimizer.state.clear()

# Build a mapping from parameter name to sharded parameter (after sharding)
name_to_sharded_param = dict(model.named_parameters(recurse=True))

# Create a mapping from old parameters to new DTensor parameters
# Note: if parameters are tied and the same parameter appears in multiple groups, PyTorch will raise an error
old_to_new_param = {}
unseen_params = set()
for group in optimizer.param_groups:
for param in group['params']:
# Note: the names of the parameters stay the same after sharding so we can do the following.
param_name = orig_param_to_name.get(param, None)
if param_name is None:
# The parameter does not come from the original model. Since `prepare_fully_shard`
# receives the optimizer itself, we have no way to recover a parameter name here,
# so we fall back to the parameter's id.
unseen_params.add(f'optimizer.param_id.{id(param)}')
elif param_name not in name_to_sharded_param:
# The original model parameter has no counterpart in the sharded model.
# This should never happen; the error message below calls it out explicitly.
unseen_params.add(f'model.param_name.{param_name}')
else:
old_to_new_param[param] = name_to_sharded_param[param_name]

# Raise an error with all the parameters that were not found in the sharded model
if len(unseen_params) > 0:
raise ValueError(
f'The same model must be passed to the optimizer and the trainer, but the '
f'following parameters were not found in the sharded model: {list(unseen_params)}. '
'Parameters prefixed with "optimizer.param_id" imply that the optimizer was given the wrong model. '
'Parameters prefixed with "model.param_name" imply a significant issue where sharding '
'has not been applied correctly.',
)

# Update param groups with new parameters
new_param_groups = []
for group in optimizer.param_groups:
new_group = {k: v for k, v in group.items() if k != 'params'}
new_params = [old_to_new_param[param] for param in group['params']]
new_group['params'] = new_params
new_param_groups.append(new_group)

# Update param groups
optimizer.param_groups.clear()
for group in new_param_groups:
optimizer.add_param_group(group)
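
For illustration, a minimal standalone sketch (not part of the diff) of the name-based remapping that `update_optimizer_modules` relies on; the toy `nn.Linear` model is a placeholder, and the actual sharding step is elided so the snippet runs without a process group:

import torch
import torch.nn as nn

toy_model = nn.Linear(4, 4)
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.01)

# Record each parameter's fully qualified name *before* sharding swaps the parameter objects.
orig_param_to_name = {p: n for n, p in toy_model.named_parameters(recurse=True)}

# fully_shard(toy_model) would replace the parameters with DTensors here (elided in this sketch).

# Names are preserved across sharding, so they bridge the old and new parameter objects.
name_to_sharded_param = dict(toy_model.named_parameters(recurse=True))
for group in toy_optimizer.param_groups:
    group['params'] = [name_to_sharded_param[orig_param_to_name[p]] for p in group['params']]

Unlike this sketch, the helper above rebuilds the groups through `add_param_group` instead of mutating `group['params']` in place, so PyTorch re-validates them (for example, it raises if a tied parameter ends up in more than one group).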


def apply_fully_shard(
model: nn.Module,
independent_submodules: list[nn.Module],
@@ -147,6 +223,7 @@ def apply_fully_shard(

def prepare_fully_shard(
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer],
fsdp2_config: FSDP2Config,
) -> None:
"""Applies FSDP2's `fully_shard` to the model according to given fsdp2_config.
@@ -158,5 +235,14 @@
Returns:
None
"""
# Build the parameter-to-name mapping before sharding replaces the parameters with DTensors
orig_param_to_name = {p: n for n, p in model.named_parameters(recurse=True)}

# Get the modules to shard
modules_to_shard, _ = get_standalone_and_tied_modules(list(model.children()))

apply_fully_shard(model, modules_to_shard, fsdp2_config)

# If the optimizer is provided, update the optimizer's parameter groups to use the sharded model's DTensor parameters
if optimizer is not None:
update_optimizer_modules(optimizer, model, orig_param_to_name)
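
As a usage sketch (assuming a distributed process group and a CUDA device are already set up, and assuming `FSDP2Config` is importable from `composer.utils.parallelism`), the optimizer is now created before sharding and handed to `prepare_fully_shard`, which rewrites its param groups in place:

import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor  # exact module path may vary by torch version

from composer.distributed.fsdp2 import prepare_fully_shard
from composer.utils.parallelism import FSDP2Config  # assumed import path

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)).to('cuda')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Shard the model, then remap the optimizer's param groups to the new DTensor parameters.
prepare_fully_shard(model, optimizer, FSDP2Config())

# Every parameter referenced by the optimizer should now be a DTensor.
assert all(isinstance(p, DTensor) for g in optimizer.param_groups for p in g['params'])

The tests below exercise the same call order through `create_trainer_with_model`, which now accepts an optional pre-built optimizer.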
103 changes: 101 additions & 2 deletions tests/trainer/test_fsdp2.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import pathlib
from typing import Optional

import pytest
import torch
@@ -69,6 +70,7 @@ def create_trainer_with_model(
num_classes: int = 10,
max_duration: str = '10ep',
use_fsdp2: bool = True,
optimizer: Optional[torch.optim.Optimizer] = None,
) -> Trainer:
"""Helper function to create a Trainer with a model, dataloader, and FSDP2 configuration."""
dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
@@ -79,7 +81,7 @@
# Trainer is not calling prepare_fully_shard yet, so we need to do it manually
fsdp2_config = FSDP2Config()
# NOTE we can only apply FSDP2 to ComposerClassifier's module field until we support auto_wrap
prepare_fully_shard(model=model.module, fsdp2_config=fsdp2_config)
prepare_fully_shard(model=model.module, fsdp2_config=fsdp2_config, optimizer=optimizer)
# NOTE module `to_empty` should only happen after the model is fully sharded and parameters are converted to DTensors
# otherwise to_empty breaks weight tying
# TODO (FSDP2) we should guardrail this in prepare_fully_shard
@@ -91,7 +93,8 @@
parallelism_config.fsdp2 = fsdp2_config
else:
parallelism_config.fsdp = FSDPConfig(state_dict_type='sharded')
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
if optimizer is None:
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
trainer = Trainer(
model=model,
optimizers=optimizer,
@@ -233,3 +236,99 @@ def test_fsdp2_load_from_fsdp1(
fsdp1_param,
param.full_tensor(),
), f'Weights: {name} should be equal after loading, however one is {fsdp1_param} and the other is {param.full_tensor()}'


@world_size(2)
@pytest.mark.gpu
@fsdp2_context
@pytest.mark.parametrize('case', ['all_params_one_group', 'subset_one_group', 'multiple_groups'])
@pytest.mark.parametrize('device', _INIT_DEVICES)
def test_fsdp2_optimizer_handling(
world_size: int,
case: str,
device: str,
):
"""Test FSDP2 correctly updates optimizer state for various configurations."""
del world_size

NUM_FEATURES = 10
NUM_CLASSES = 10
model = PartialWeightTiedModel(num_features=NUM_FEATURES, device=device)

all_params_list = list(model.parameters())
fc1_params_list = list(model.mlp.fc1.parameters())
fc3_params_list = list(model.fc3.parameters())

if case == 'all_params_one_group':
optimizer_input = [{'params': all_params_list, 'lr': 0.01}]
elif case == 'subset_one_group':
optimizer_input = [{'params': fc1_params_list, 'lr': 0.02}] # Same as fc2_params_list (since the weights are tied)
elif case == 'multiple_groups':
optimizer_input = [
{
'params': fc1_params_list,
'lr': 0.01,
}, # Same as fc2_params_list (since the weights are tied)
{
'params': fc3_params_list,
'lr': 0.02,
},
]
else:
raise ValueError(f'Invalid case: {case}')

optimizer = torch.optim.Adam(optimizer_input)
trainer = create_trainer_with_model(model=model, num_classes=NUM_CLASSES, use_fsdp2=True, optimizer=optimizer)

def validate_optimizer_state(current_optimizer: torch.optim.Optimizer, stage: str):
assert len(current_optimizer.param_groups) == len(optimizer_input), \
f'[{case}/{stage}] Group count mismatch. Expected {len(optimizer_input)}, Got {len(current_optimizer.param_groups)}'
for i, group in enumerate(current_optimizer.param_groups):
opt_params = group['params']
# Check that the number of parameters in the optimizer group matches the number of parameters in the input
assert len(opt_params) == len(optimizer_input[i]['params']), \
f"[{case}/{stage}] Group {i}: Param count mismatch. Expected {len(optimizer_input[i]['params'])}, Got {len(opt_params)}"

# Check that all parameters are DTensor
assert all(isinstance(p, DTensor) for p in opt_params), \
f'[{case}/{stage}] Group {i}: Not all parameters are DTensors'

# Check that all keys match between input and current groups
input_keys = set(optimizer_input[i].keys())
group_keys = set(group.keys())
assert input_keys == group_keys, \
f'[{case}/{stage}] Group {i}: Key mismatch. Expected {input_keys}, Got {group_keys}'

# Check values for all keys
for key in input_keys:
if key != 'params':
assert group[key] == optimizer_input[i][key], \
f'[{case}/{stage}] Group {i}: {key} mismatch. Expected {optimizer_input[i][key]}, Got {group[key]}'

# Validate optimizer state after sharding and before training
validate_optimizer_state(optimizer, stage='after_fully_shard')

trainer.fit()

# Validate optimizer state after training
validate_optimizer_state(optimizer, stage='after_fit')


@world_size(2)
@pytest.mark.gpu
@fsdp2_context
def test_fsdp2_optimizer_raises_error_when_optimizer_modules_dont_match(world_size: int,):
"""Test FSDP2 raises an error when the optimizer modules don't match the model modules."""
del world_size

NUM_FEATURES = 10
NUM_CLASSES = 10
model = SimpleComposerMLP(num_features=NUM_FEATURES, device='cuda', num_classes=NUM_CLASSES)
other_model = SimpleWeightTiedModel(num_features=NUM_FEATURES, device='cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
with pytest.raises(ValueError) as e:
create_trainer_with_model(model=other_model, num_classes=NUM_CLASSES, use_fsdp2=True, optimizer=optimizer)
# Check that error message uses the correct prefix implying optimizer difference
# We check with `optimizer.param_id.` (with the period) since `optimizer.param_id` exists
# by default in the error message's legend
assert 'optimizer.param_id.' in str(e.value)
4 changes: 2 additions & 2 deletions tests/trainer/test_fsdp2_gradscaler.py
@@ -43,9 +43,9 @@ def test_fsdp2_with_gradscaler_inf(world_size: int):
dtype = torch.float16

model = SimpleModel().to('cuda')
# Apply fully_shard to the model
prepare_fully_shard(model, FSDP2Config())
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Apply fully_shard to the model
prepare_fully_shard(model, optimizer, FSDP2Config())

# dummy inputs and targets
inputs = torch.randn(1, 2, device='cuda', dtype=dtype)