@@ -5,6 +5,7 @@
import torch
from packaging import version
from torch.utils.data import DataLoader
+
from composer.models import ComposerClassifier
from composer.trainer.trainer import Trainer
from composer.utils import dist
@@ -27,7 +28,6 @@

@pytest.mark.parametrize('model', [SimpleWeightTiedModel, PartialWeightTiedModel])
@pytest.mark.parametrize('device', _INIT_DEVICES)
-@pytest.mark.parametrize('optimizer', [torch.optim.Adam, torch.optim.SGD])
@world_size(2)
@pytest.mark.gpu
@pytest.mark.filterwarnings('ignore:FSDP2 Config/APIs are experimental*:UserWarning')
@@ -36,7 +36,6 @@ def test_fsdp2_initialization_with_tied_params(
    model: ComposerClassifier,
    world_size: int,
    device: str,
-    optimizer: type[torch.optim.Optimizer],
):
    """test FSDP2 initialization for a simple model with weight tying and a model where two modules
    from separate submodules have weight tying applied.
@@ -53,26 +52,16 @@ def test_fsdp2_initialization_with_tied_params(
        mp_policy=None,
        offload_policy=None,
    )
-    optimizer = optimizer(model.parameters(), lr=0.1)
-    prepare_fully_shard(model=model.module, optimizer=optimizer, fsdp2_config=fsdp2_config)
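+    # Create the optimizer on the unsharded parameters; prepare_fully_shard repoints its param
+    # groups at the sharded DTensor parameters (exercised in test_fsdp2_optimizer_handling below).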
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    prepare_fully_shard(model=model.module, optimizers=optimizer, fsdp2_config=fsdp2_config)

    # Initialization checks
    assert len(model.mlp._forward_pre_hooks) == 1, 'Expected 1 forward pre-hook on the mlp module'
    assert len(model.mlp.fc1._forward_pre_hooks) == 0, 'Expected 0 forward pre-hook on the fc1 module'
    assert len(model.mlp.fc2._forward_pre_hooks) == 0, 'Expected 0 forward pre-hook on the fc2 module'
    assert len(model.module._forward_pre_hooks) == 1, 'Expected 1 forward pre-hook on the root module'
-
-    # Check that the weights are DTensor
    assert isinstance(model.mlp.fc1.weight, DTensor), 'mlp.fc1.weight should be a DTensor'
    assert isinstance(model.mlp.fc2.weight, DTensor), 'mlp.fc2.weight should be a DTensor'
-    # Check that all optimizer parameters are DTensor (we only have one param group)
-    assert len(optimizer.param_groups) == 1, 'Expected 1 param group in optimizer'
-    assert len(optimizer.param_groups[0]['params']) >= 1, 'Expected at least 1 parameter in optimizer (depends on the model)'
-    assert all(isinstance(param, DTensor) for param in optimizer.param_groups[0]['params']), 'All parameters in optimizer should be DTensor'
-    # Validate that the ids of the parameters in the optimizer exist in the model
-    model_param_ids = [id(p) for p in model.parameters()]
-    for param in optimizer.param_groups[0]['params']:
-        assert id(param) in model_param_ids, 'Parameter id in optimizer does not match parameter id in model'

    if isinstance(model, PartialWeightTiedModel):
        assert len(model.fc3._forward_pre_hooks) == 1, 'Expected 1 forward pre-hook on the fc3 module'
@@ -98,3 +87,82 @@ def test_fsdp2_initialization_with_tied_params(
    weight_2 = model.mlp.fc2.weight.full_tensor()
    assert (model.mlp.fc1.weight is model.mlp.fc2.weight)
    assert (torch.equal(weight_1, weight_2))
+
+
+@pytest.mark.parametrize('case', ['all_params_one_group', 'subset_one_group', 'multiple_groups'])
+@pytest.mark.parametrize('device', _INIT_DEVICES)
+@world_size(2)
+@pytest.mark.gpu
+@pytest.mark.filterwarnings('ignore:FSDP2 Config/APIs are experimental*:UserWarning')
+@pytest.mark.skipif(SKIP_TEST, reason='FSDP2 is not available in torch < 2.6.0')
+def test_fsdp2_optimizer_handling(
+    case: str,
+    world_size: int,
+    device: str,
+):
+    """Test FSDP2 correctly updates optimizer state for various configurations."""
+    del world_size
+
+    num_classes = 10
+    model = PartialWeightTiedModel(num_features=num_classes, device=device)
+    dataset = RandomClassificationDataset(shape=(num_classes,), size=10, num_classes=num_classes)
+    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
+
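+    # Collect references to the raw (unsharded) parameters; after prepare_fully_shard the optimizer
+    # should hold the corresponding DTensor shards instead.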
+    all_params_list = list(model.parameters())
+    fc1_params_list = list(model.mlp.fc1.parameters())
+    fc3_params_list = list(model.fc3.parameters())
+    remaining_params_list = [
+        p for p in all_params_list
+        if not any(p is param for param in fc1_params_list)
+        and not any(p is param for param in fc3_params_list)
+        and not any(p is param for param in model.mlp.fc2.parameters())  # To avoid double counting the tied parameter
+    ]
+
+    if case == 'all_params_one_group':
+        optimizer_input = [{'params': all_params_list, 'lr': 0.01}]
+    elif case == 'subset_one_group':
+        optimizer_input = [{'params': fc1_params_list, 'lr': 0.02}]
+    elif case == 'multiple_groups':
+        optimizer_input = [
+            {'params': fc1_params_list, 'lr': 0.01},
+            {'params': fc3_params_list, 'lr': 0.02},
+            {'params': remaining_params_list, 'lr': 0.03},
+        ]
+
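+    # Adam is built from the selected param-group layout: a single group with every parameter, a
+    # group over a strict subset, or multiple groups with distinct learning rates.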
+    optimizer = torch.optim.Adam(optimizer_input)
+
+    fsdp2_config = FSDP2Config(
+        device_mesh=None,
+        reshard_after_forward=True,
+        mp_policy=None,
+        offload_policy=None,
+    )
+    prepare_fully_shard(model=model.module, optimizers=optimizer, fsdp2_config=fsdp2_config)
+
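+    # After sharding, the optimizer should keep the same group structure, per-group param counts,
+    # and learning rates, with every parameter replaced by its DTensor shard.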
+    def validate_optimizer_state(current_optimizer: torch.optim.Optimizer, stage: str):
+        assert len(current_optimizer.param_groups) == len(optimizer_input), \
+            f"[{case}/{stage}] Group count mismatch. Expected {len(optimizer_input)}, Got {len(current_optimizer.param_groups)}"
+        for i, group in enumerate(current_optimizer.param_groups):
+            opt_params = group['params']
+            assert len(opt_params) == len(optimizer_input[i]['params']), \
+                f"[{case}/{stage}] Group {i}: Param count mismatch. Expected {len(optimizer_input[i]['params'])}, Got {len(opt_params)}"
+            assert all(isinstance(p, DTensor) for p in opt_params), \
+                f"[{case}/{stage}] Group {i}: Not all parameters are DTensors"
+            assert group['lr'] == optimizer_input[i]['lr'], \
+                f"[{case}/{stage}] Group {i}: LR mismatch. Expected {optimizer_input[i]['lr']}, Got {group['lr']}"
+
+    validate_optimizer_state(optimizer, stage="after_fully_shard")
+
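+    # Materialize the parameters (the model may have been constructed on a meta device) and re-run
+    # per-module initialization before training.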
+    model.to_empty(device='cuda')
+    for module in model.modules():
+        model.param_init_fn(module)
+
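+    # A single-batch fit exercises the sharded model and optimizer end to end.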
+    trainer = Trainer(
+        model=model,
+        optimizers=optimizer,
+        train_dataloader=dataloader,
+        max_duration='1ba',
+    )
+    trainer.fit()
+
+    validate_optimizer_state(optimizer, stage="after_fit")