@@ -33,7 +33,7 @@ def create_trainer_with_model(
     activation_checkpointing: bool = False,
     activation_cpu_offload: bool = False,
     auto_microbatching: bool = False,
-    sync_module_states: bool = False,
+    fsdp1_sync_module_states: bool = False,
 ) -> Trainer:
     """Helper function to create a Trainer with a model, dataloader, and FSDP2 configuration."""
     dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
@@ -44,12 +44,11 @@ def create_trainer_with_model(
         parallelism_config.fsdp2 = FSDP2Config(
             activation_checkpointing=activation_checkpointing,
             activation_cpu_offload=activation_cpu_offload,
-            sync_module_states=sync_module_states,
         )
     else:
         parallelism_config.fsdp = FSDPConfig(
             state_dict_type='sharded',
-            sync_module_states=sync_module_states,
+            sync_module_states=fsdp1_sync_module_states,
         )
     if optimizer is None:
         optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -346,6 +345,7 @@ def _create_model_with_mixed_init(model_class: type, num_features: int, device:
         resolved_device = device if dist.get_local_rank() == 0 else 'meta'
 
         # set the bias to be True for SimpleComposerMLP
+        # which is used for a later test
         kwargs = {}
         if model_class == SimpleComposerMLP:
             kwargs['add_bias'] = True
@@ -360,14 +360,17 @@ def _create_model_with_mixed_init(model_class: type, num_features: int, device:
 
     @staticmethod
     def _train_model_and_extract_weights(model, use_fsdp2: bool):
-        """Helper function to train a model and extract its weights."""
+        """Helper function to train a mixed init model and extract its weights."""
         optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
+        kwargs = {}
+        if not use_fsdp2:
+            kwargs['fsdp1_sync_module_states'] = True
         trainer = create_trainer_with_model(
             model=model,
             max_duration=f'10ba',
             use_fsdp2=use_fsdp2,
             optimizer=optimizer,
-            sync_module_states=True,
+            **kwargs,
         )
 
         trainer.fit()
@@ -403,31 +406,6 @@ def _compare_weights(fsdp1_weights: dict, fsdp2_weights: dict, tolerance: float
             assert diff < tolerance, \
                 f'Weight difference for {name} exceeds tolerance: {diff} > {tolerance}.'
 
-    @world_size(2)
-    @pytest.mark.gpu
-    # Note that we are testing on a GPU instance just to make sure we can initialize
-    # on CPU and then move to GPU.
-    @pytest.mark.parametrize('device', ['cuda', 'cpu'])
-    def test_fsdp2_sync_module_states_raises_error_when_invalid_mixed_initialization(
-        self,
-        world_size: int,
-        device: str,
-    ):
-        """Test that FSDP2 sync_module_states raises an error when invalid mixed initialization is detected."""
-        del world_size
-        resolved_device = device if dist.get_local_rank() == 0 else 'meta'
-        model = self._create_model_with_mixed_init(SimpleComposerMLP, 10, resolved_device)
-        with pytest.raises(ValueError) as e:
-            create_trainer_with_model(model=model, num_classes=10, use_fsdp2=True, sync_module_states=False)
-        assert 'Detected mixed initialization where some ranks have model on cpu or gpu and some ranks are on meta' in str(
-            e.value,
-        )
-
-        model_2 = self._create_model_with_mixed_init(SimpleComposerMLP, 10, 'meta')
-        with pytest.raises(ValueError) as e:
-            create_trainer_with_model(model=model_2, num_classes=10, use_fsdp2=True, sync_module_states=True)
-        assert 'Detected that all ranks (including rank 0) have model on meta' in str(e.value)
-
     @world_size(2)
     @pytest.mark.gpu
     # Note that we are testing on a GPU instance just to make sure we can initialize
@@ -458,8 +436,8 @@ def param_init_fn(module: torch.nn.Module):
             model=model,
             num_classes=10,
             use_fsdp2=True,
-            sync_module_states=True,
         )
+        assert trainer.state.fsdp_config.sync_module_states, 'sync_module_states should be True'  # type: ignore
 
         module = trainer.state.model.module  # type: ignore
         assert torch.equal(
@@ -500,9 +478,8 @@ def test_fsdp2_mixed_init_does_not_break_weight_tying(
             model=model,
             num_classes=10,
             use_fsdp2=True,
-            sync_module_states=True,
         )
-
+        assert trainer.state.fsdp_config.sync_module_states, 'sync_module_states should be True'  # type: ignore
         # Check that the weights are correctly tied after training
         trainer.fit()
         weight_1 = model.mlp.fc1.weight.full_tensor()  # type: ignore
@@ -532,8 +509,9 @@ def test_fsdp2_sync_module_state_aligns_with_optimizer_state(
             num_classes=10,
             use_fsdp2=True,
             optimizer=optimizer,
-            sync_module_states=True,
         )
+
+        assert trainer.state.fsdp_config.sync_module_states, 'sync_module_states should be True'  # type: ignore
         trainer.fit()
 
         assert torch.equal(
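
Illustrative usage sketch (not part of the diff): after this change, the FSDP2 path of the test helper no longer accepts a sync flag, while the FSDP1 path keeps the old behavior under the renamed fsdp1_sync_module_states keyword. The model below is assumed to be built the same way as in _create_model_with_mixed_init above; the class count is an arbitrary example value.

# `model` is assumed to come from _create_model_with_mixed_init(SimpleComposerMLP, 10, device).
# FSDP2 path: no sync flag is passed; the tests above instead assert that
# trainer.state.fsdp_config.sync_module_states ends up True on its own.
trainer_fsdp2 = create_trainer_with_model(model=model, num_classes=10, use_fsdp2=True)

# FSDP1 path: the flag is still forwarded to FSDPConfig.sync_module_states,
# now under the fsdp1_ prefix to make its scope explicit.
trainer_fsdp1 = create_trainer_with_model(
    model=model,
    num_classes=10,
    use_fsdp2=False,
    fsdp1_sync_module_states=True,
)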