 from typing import Optional, Union

 from torch import nn
+from torch.optim import Optimizer
 from torch.distributed._tensor.device_mesh import DeviceMesh
 from torch.distributed.fsdp._fully_shard import fully_shard
 from torch.distributed.fsdp._fully_shard._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
@@ -119,6 +120,50 @@ def _check_param_sharing(module: nn.Module):
     # Start the check from the root model
     _check_param_sharing(model)

+def update_optimizer_modules(
+    optimizer: Optimizer,
+    modules_to_shard: list[nn.Module],
+    model: nn.Module,
+    orig_param_id_to_name: dict[int, str],
+) -> None:
+    """Updates the optimizer's parameter groups to use the sharded model parameters.
+    Assumes no training has occurred yet and the optimizer state is empty.
+
+    Args:
+        optimizer (Optimizer): The optimizer to update.
+        modules_to_shard (list[nn.Module]): The modules that have been sharded.
+        model (nn.Module): The parent model that is also sharded.
+        orig_param_id_to_name (dict[int, str]): Mapping from original parameter IDs to their names.
+    """
+    # Build a mapping from parameter name to sharded parameter (after sharding)
+    name_to_sharded_param = dict(model.named_parameters())
+    for module in modules_to_shard:
+        name_to_sharded_param.update(dict(module.named_parameters()))
+
+    # Create a mapping from old parameters to new DTensor parameters
+    old_to_new_param = {}
+    for group in optimizer.param_groups:
+        for param in group['params']:
+            param_name = orig_param_id_to_name.get(id(param))
+            if param_name is not None and param_name in name_to_sharded_param:
+                old_to_new_param[param] = name_to_sharded_param[param_name]
+            else:
+                # TODO: Look into whether we will ever hit this case...
+                raise ValueError(f"Parameter {param} not found in model")
+
+    # Rebuild each param group with the new parameters, keeping all other group options
+    new_param_groups = []
+    for group in optimizer.param_groups:
+        new_group = {k: v for k, v in group.items() if k != 'params'}
+        new_params = [old_to_new_param[param] for param in group['params']]
+        new_group['params'] = new_params
+        new_param_groups.append(new_group)
+
+    # Replace the optimizer's param groups with the rebuilt ones
+    optimizer.param_groups.clear()
+    for group in new_param_groups:
+        optimizer.add_param_group(group)
+

 def apply_fully_shard(
     model: nn.Module,
@@ -178,6 +223,7 @@ def apply_fully_shard(

 def prepare_fully_shard(
     model: nn.Module,
+    optimizer: Optional[Optimizer],
     fsdp2_config: FSDP2Config,
 ) -> None:
     """Applies FSDP2's `fully_shard` to the model according to given fsdp2_config.
@@ -190,4 +236,14 @@ def prepare_fully_shard(
         None
     """
     modules_to_shard, _ = get_standalone_and_tied_modules(list(model.children()))
+
+    # Build the parameter ID to name mapping (with no duplicates)
+    orig_param_id_to_name = {id(param): name for name, param in model.named_parameters()}
+    for module in modules_to_shard:
+        orig_param_id_to_name.update({id(param): name for name, param in module.named_parameters()})
+
     apply_fully_shard(model, modules_to_shard, fsdp2_config)
+
+    # After the model is sharded in place, update the optimizer's param groups to use the DTensor parameters
+    if optimizer is not None:
+        update_optimizer_modules(optimizer, modules_to_shard, model, orig_param_id_to_name)
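A minimal, self-contained sketch of the param-group remapping that `update_optimizer_modules` performs: the toy `nn.Linear` models, the SGD optimizer, and the swap from `old_model` to `new_model` below are illustrative stand-ins (in the real flow the replacement parameters are the DTensors produced by `fully_shard`), but the rebuild of `optimizer.param_groups` follows the same steps.

```python
import torch
from torch import nn

# Toy stand-ins: old_model plays the unsharded model, new_model the "sharded" one
# (in the real flow the replacement parameters are DTensors; here they are plain tensors).
old_model = nn.Linear(4, 4)
new_model = nn.Linear(4, 4)

optimizer = torch.optim.SGD(old_model.parameters(), lr=0.1, weight_decay=0.01)

# Map original parameter ids to names, then names to the replacement parameters.
orig_param_id_to_name = {id(p): name for name, p in old_model.named_parameters()}
name_to_new_param = dict(new_model.named_parameters())

# Rebuild each param group: keep every option (lr, weight_decay, ...) and swap in
# the replacement parameters that correspond to the same names.
new_param_groups = []
for group in optimizer.param_groups:
    new_group = {k: v for k, v in group.items() if k != 'params'}
    new_group['params'] = [
        name_to_new_param[orig_param_id_to_name[id(p)]] for p in group['params']
    ]
    new_param_groups.append(new_group)

optimizer.param_groups.clear()
for group in new_param_groups:
    optimizer.add_param_group(group)

# The optimizer now steps on new_model's parameters with the original settings.
assert optimizer.param_groups[0]['lr'] == 0.1
assert optimizer.param_groups[0]['params'][0] is new_model.weight
```

Clearing `param_groups` and re-adding the rebuilt groups via `add_param_group` keeps per-group hyperparameters intact while re-pointing `params`, which is only safe because no optimizer state has been created yet.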