 # SPDX-License-Identifier: Apache-2.0

 import pathlib
+from typing import Optional

 import pytest
 import torch
@@ -69,6 +70,7 @@ def create_trainer_with_model(
     num_classes: int = 10,
     max_duration: str = '10ep',
     use_fsdp2: bool = True,
+    optimizer: Optional[torch.optim.Optimizer] = None,
 ) -> Trainer:
     """Helper function to create a Trainer with a model, dataloader, and FSDP2 configuration."""
     dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
@@ -79,7 +81,7 @@ def create_trainer_with_model(
         # Trainer is not calling prepare_fully_shard yet, so we need to do it manually
         fsdp2_config = FSDP2Config()
         # NOTE we can only apply FSDP2 to ComposerClassifier's module field until we support auto_wrap
-        prepare_fully_shard(model=model.module, fsdp2_config=fsdp2_config)
+        prepare_fully_shard(model=model.module, fsdp2_config=fsdp2_config, optimizer=optimizer)
         # NOTE module to_empty should only happen after the model is fully sharded and parameters are converted to DTensor,
         # otherwise to_empty breaks weight tying
         # TODO (FSDP2) we should guardrail this in prepare_fully_shard
@@ -91,7 +93,8 @@ def create_trainer_with_model(
         parallelism_config.fsdp2 = fsdp2_config
     else:
         parallelism_config.fsdp = FSDPConfig(state_dict_type='sharded')
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    if optimizer is None:
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
     trainer = Trainer(
         model=model,
         optimizers=optimizer,
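Context for the hunk above, not part of the diff: `fully_shard` replaces a module's `nn.Parameter`s with DTensor-backed parameters, so an optimizer constructed before sharding still references the old, replaced tensors. Forwarding the optimizer into `prepare_fully_shard` allows its param groups to be updated to the new DTensor parameters, which is exactly what the tests added below assert. A minimal sketch of that kind of remap, assuming an old-to-new parameter mapping is available; `remap_optimizer_params` and `old_to_new` are illustrative names, not Composer's API:

import torch


def remap_optimizer_params(
    optimizer: torch.optim.Optimizer,
    old_to_new: dict,  # maps pre-shard nn.Parameters to their sharded replacements
) -> None:
    """Swap stale parameter references in each param group in place (illustrative only)."""
    for group in optimizer.param_groups:
        group['params'] = [old_to_new.get(p, p) for p in group['params']]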
@@ -233,3 +236,99 @@ def test_fsdp2_load_from_fsdp1(
             fsdp1_param,
             param.full_tensor(),
         ), f'Weights: {name} should be equal after loading, however one is {fsdp1_param} and the other is {param.full_tensor()}'
+
+
+@world_size(2)
+@pytest.mark.gpu
+@fsdp2_context
+@pytest.mark.parametrize('case', ['all_params_one_group', 'subset_one_group', 'multiple_groups'])
+@pytest.mark.parametrize('device', _INIT_DEVICES)
+def test_fsdp2_optimizer_handling(
+    world_size: int,
+    case: str,
+    device: str,
+):
+    """Test FSDP2 correctly updates optimizer state for various configurations."""
+    del world_size
+
+    NUM_FEATURES = 10
+    NUM_CLASSES = 10
+    model = PartialWeightTiedModel(num_features=NUM_FEATURES, device=device)
+
+    all_params_list = list(model.parameters())
+    fc1_params_list = list(model.mlp.fc1.parameters())
+    fc3_params_list = list(model.fc3.parameters())
+
+    if case == 'all_params_one_group':
+        optimizer_input = [{'params': all_params_list, 'lr': 0.01}]
+    elif case == 'subset_one_group':
+        optimizer_input = [{'params': fc1_params_list, 'lr': 0.02}]  # Same as fc2_params_list (since tied weights)
+    elif case == 'multiple_groups':
+        optimizer_input = [
+            {
+                'params': fc1_params_list,
+                'lr': 0.01,
+            },  # Same as fc2_params_list (since tied weights)
+            {
+                'params': fc3_params_list,
+                'lr': 0.02,
+            },
+        ]
+    else:
+        raise ValueError(f'Invalid case: {case}')
+
+    optimizer = torch.optim.Adam(optimizer_input)
+    trainer = create_trainer_with_model(model=model, num_classes=NUM_CLASSES, use_fsdp2=True, optimizer=optimizer)
+
+    def validate_optimizer_state(current_optimizer: torch.optim.Optimizer, stage: str):
+        assert len(current_optimizer.param_groups) == len(optimizer_input), \
+            f'[{case}/{stage}] Group count mismatch. Expected {len(optimizer_input)}, Got {len(current_optimizer.param_groups)}'
+        for i, group in enumerate(current_optimizer.param_groups):
+            opt_params = group['params']
+            # Check that the number of parameters in the optimizer group matches the number of parameters in the input
+            assert len(opt_params) == len(optimizer_input[i]['params']), \
+                f"[{case}/{stage}] Group {i}: Param count mismatch. Expected {len(optimizer_input[i]['params'])}, Got {len(opt_params)}"
+
+            # Check that all parameters are DTensors
+            assert all(isinstance(p, DTensor) for p in opt_params), \
+                f'[{case}/{stage}] Group {i}: Not all parameters are DTensors'
+
+            # Check that all keys match between input and current groups
+            input_keys = set(optimizer_input[i].keys())
+            group_keys = set(group.keys())
+            assert input_keys == group_keys, \
+                f'[{case}/{stage}] Group {i}: Key mismatch. Expected {input_keys}, Got {group_keys}'

+            # Check values for all keys
+            for key in input_keys:
+                if key != 'params':
+                    assert group[key] == optimizer_input[i][key], \
+                        f'[{case}/{stage}] Group {i}: {key} mismatch. Expected {optimizer_input[i][key]}, Got {group[key]}'
+
+    # Validate optimizer state after sharding and before training
+    validate_optimizer_state(optimizer, stage='after_fully_shard')
+
+    trainer.fit()
+
+    # Validate optimizer state after training
+    validate_optimizer_state(optimizer, stage='after_fit')
+
+
+@world_size(2)
+@pytest.mark.gpu
+@fsdp2_context
+def test_fsdp2_optimizer_raises_error_when_optimizer_modules_dont_match(world_size: int,):
+    """Test FSDP2 raises an error when the optimizer modules don't match the model modules."""
+    del world_size
+
+    NUM_FEATURES = 10
+    NUM_CLASSES = 10
+    model = SimpleComposerMLP(num_features=NUM_FEATURES, device='cuda', num_classes=NUM_CLASSES)
+    other_model = SimpleWeightTiedModel(num_features=NUM_FEATURES, device='cuda')
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+    with pytest.raises(ValueError) as e:
+        create_trainer_with_model(model=other_model, num_classes=NUM_CLASSES, use_fsdp2=True, optimizer=optimizer)
+    # Check that the error message uses the correct prefix, implying an optimizer mismatch
+    # We check with `optimizer.param_id.` (with the period) since `optimizer.param_id` exists
+    # by default in the error message's legend
+    assert 'optimizer.param_id.' in str(e.value)
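Aside, not part of the diff: the mismatch case in the last test comes down to the optimizer holding parameters that are not the trained model's parameters. Conceptually that reduces to an identity-membership check like the sketch below; `optimizer_params_match_model` is a hypothetical helper, not Composer's implementation, which instead raises a ValueError whose message includes `optimizer.param_id.` keys, as the assertion above relies on.

import torch


def optimizer_params_match_model(optimizer: torch.optim.Optimizer, model: torch.nn.Module) -> bool:
    """Return True only if every optimizer parameter is one of the model's parameters (by identity)."""
    model_param_ids = {id(p) for p in model.parameters()}
    return all(
        id(p) in model_param_ids
        for group in optimizer.param_groups
        for p in group['params']
    )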