Skip to content

Commit 02e818e

Browse files
committed
formatting fixes + additional comments
1 parent 75c7617 commit 02e818e

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

composer/distributed/fsdp2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from torch.distributed.fsdp.wrap import CustomPolicy
1212

1313
from composer.distributed.fsdp2_utils import (
14-
generate_default_policy,
1514
check_param_tying,
15+
generate_default_policy,
1616
get_standalone_and_tied_modules,
1717
legalize_param_sharing_between_modules,
1818
update_optimizer_modules,

composer/distributed/fsdp2_utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
# FSDP2 Weight Tying Functions
1717
# TODO: These functions are all relatively similar to each other, we should consider
18-
# refactoring them in the future to be simpler.
18+
# refactoring them in the future to be simpler. We also might benefit from moving these
19+
# weight tying functions to a new file (in a potential `fsdp2_utils` directory).
1920

2021

2122
def legalize_param_sharing_between_modules(model: nn.Module, modules_to_shard: list[nn.Module]) -> None:
@@ -139,9 +140,8 @@ def _recursive_get_params(module: nn.Module, prefix: str = '') -> None:
139140

140141
_recursive_get_params(model)
141142

142-
# Filter to keep only groups where the same parameter object has multiple FQNs
143-
tying_groups = [fqns for fqns in param_object_to_fqns.values() if len(fqns) > 1]
144-
return tying_groups
143+
# Return a list of sets, each set contains the FQNs for a tied parameter group
144+
return list(param_object_to_fqns.values())
145145

146146

147147
@contextlib.contextmanager
@@ -245,7 +245,12 @@ def update_optimizer_modules(
245245

246246

247247
def generate_default_policy(parent_model: nn.Module) -> CustomPolicy:
248-
# The same policy as FSDP1 with some caveats around the parent_model (root_module)
248+
"""Generates the default fsdp wrap policy for FSDP2.
249+
250+
This policy is the same as the default policy in FSDP1 with some caveats around
251+
how the root_module (parent_model) is handled to best support FSDP2.
252+
"""
253+
249254
def lambda_fn(current_module: nn.Module) -> Union[bool, dict[str, Any]]:
250255
ret = False
251256
if hasattr(current_module, '_fsdp_wrap'):

0 commit comments

Comments
 (0)