|
15 | 15 |
|
16 | 16 | # FSDP2 Weight Tying Functions
|
17 | 17 | # TODO: These functions are all relatively similar to each other; we should consider
|
18 |
| -# refactoring them in the future to be simpler. |
| 18 | +# refactoring them in the future to be simpler. We also might benefit from moving these |
| 19 | +# weight tying functions to a new file (in a potential `fsdp2_utils` directory). |
19 | 20 |
|
20 | 21 |
|
21 | 22 | def legalize_param_sharing_between_modules(model: nn.Module, modules_to_shard: list[nn.Module]) -> None:
|
@@ -139,9 +140,8 @@ def _recursive_get_params(module: nn.Module, prefix: str = '') -> None:
|
139 | 140 |
|
140 | 141 | _recursive_get_params(model)
|
141 | 142 |
|
142 |
| - # Filter to keep only groups where the same parameter object has multiple FQNs |
143 |
| - tying_groups = [fqns for fqns in param_object_to_fqns.values() if len(fqns) > 1] |
144 |
| - return tying_groups |
| 143 | + # Return a list of sets, each set contains the FQNs for a tied parameter group |
| 144 | + return list(param_object_to_fqns.values()) |
145 | 145 |
|
146 | 146 |
|
147 | 147 | @contextlib.contextmanager
|
@@ -245,7 +245,12 @@ def update_optimizer_modules(
|
245 | 245 |
|
246 | 246 |
|
247 | 247 | def generate_default_policy(parent_model: nn.Module) -> CustomPolicy:
|
248 |
| - # The same policy as FSDP1 with some caveats around the parent_model (root_module) |
| 248 | + """Generates the default fsdp wrap policy for FSDP2. |
| 249 | +
|
| 250 | + This policy is the same as the default policy in FSDP1 with some caveats around |
| 251 | + how the root_module (parent_model) is handled to best support FSDP2. |
| 252 | + """ |
| 253 | + |
249 | 254 | def lambda_fn(current_module: nn.Module) -> Union[bool, dict[str, Any]]:
|
250 | 255 | ret = False
|
251 | 256 | if hasattr(current_module, '_fsdp_wrap'):
|
|
0 commit comments