Commit 6872ec1

used settable attrs
1 parent 0272447 commit 6872ec1

4 files changed: +35 -57 lines changed

composer/distributed/shared_utils.py

Lines changed: 8 additions & 20 deletions
@@ -186,23 +186,11 @@ def validate_model_requires_state_sync(model: nn.Module, fsdp_config: FSDP2Config
     requires_sync = all_ranks_meta.item() == 0 and any_ranks_meta.item() == 1
 
     if not fsdp_config.sync_module_states and requires_sync:
-        raise ValueError(
-            'Detected mixed initialization where some ranks have model on cpu or '
-            'gpu and some ranks are on meta. Please set '
-            'parallelism_config["fsdp"]["sync_module_states"] = True. '
-            'Make sure that rank 0 is the only rank where the model is NOT on meta.',
-        )
-
-    # Note: We do this instead of raising an error if
-    # (rank=0 && device=meta) || (rank!=0 && device!=meta)
-    # because that only raises an error on disparate ranks
-    # and that crashes the program, we instead just assert
-    # the setup later and note in the earlier error if the setup is invalid
-    if fsdp_config.sync_module_states and all_ranks_meta.item() == 1:
-        raise ValueError(
-            'Detected that all ranks (including rank 0) have model on meta. '
-            'Will not sync module states. Please set '
-            'parallelism_config["fsdp"]["sync_module_states"] = False. '
-            'If you want to sync module states, make sure that rank 0 '
-            'is the only rank where the model is NOT on meta.',
-        )
+        fsdp_config.sync_module_states = True
+
+    # Asserts that the rank setup is valid
+    if fsdp_config.sync_module_states:
+        if dist.get_global_rank() == 0:
+            assert rank_on_meta == 0
+        else:
+            assert rank_on_meta == 1
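
With this change, mixed initialization (rank 0 holding real weights while the other ranks stay on meta) no longer raises; the validator enables sync_module_states itself and only asserts that the per-rank device placement is consistent. For context, a minimal standalone sketch of the detection pattern, assuming an initialized torch.distributed process group; the function name and tensors below are illustrative, not Composer's own code:

import torch
import torch.distributed as dist
import torch.nn as nn


def detect_mixed_init(model: nn.Module) -> bool:
    """Illustrative sketch: return True when some (but not all) ranks hold the model on meta."""
    on_meta = float(all(p.device.type == 'meta' for p in model.parameters()))
    all_ranks_meta = torch.tensor([on_meta])
    any_ranks_meta = torch.tensor([on_meta])
    dist.all_reduce(all_ranks_meta, op=dist.ReduceOp.MIN)  # 1 only if every rank is on meta
    dist.all_reduce(any_ranks_meta, op=dist.ReduceOp.MAX)  # 1 if at least one rank is on meta
    # Mixed init: at least one rank on meta, but not all of them -> module states must be synced
    return all_ranks_meta.item() == 0 and any_ranks_meta.item() == 1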

composer/utils/parallelism.py

Lines changed: 12 additions & 3 deletions
@@ -80,14 +80,23 @@ class FSDP2Config:
     # in most of our use cases, we can decouple these two attributes from the FSDP2Config class.
     activation_checkpointing: bool = False
     activation_cpu_offload: bool = False
-    sync_module_states: bool = False
-
     verbose: bool = False
 
+    # Settable attrs that are automatically set during training
+    _sync_module_states: bool = field(default=False, init=False, repr=False)
+
+    @property
+    def sync_module_states(self) -> bool:
+        return self._sync_module_states
+
+    @sync_module_states.setter
+    def sync_module_states(self, value: bool):
+        self._sync_module_states = value
+
     @classmethod
     def settable_attrs(cls) -> set[str]:
         """Return a set of all settable attributes of FSDP2Config."""
-        return {field.name for field in fields(cls)}
+        return {field.name for field in fields(cls) if not field.name.startswith('_')}
 
     @classmethod
     def from_compatible_attrs(cls, attrs: dict[str, Any]) -> 'FSDP2Config':
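
On the config side, sync_module_states becomes a property backed by a private dataclass field that is excluded from __init__, repr, and settable_attrs, so users can no longer pass it while the trainer can still flip it at runtime. A self-contained sketch of that pattern (ExampleConfig is an illustrative stand-in, not the real FSDP2Config):

from dataclasses import dataclass, field, fields


@dataclass
class ExampleConfig:
    verbose: bool = False

    # Set automatically at runtime, never passed to __init__
    _sync_module_states: bool = field(default=False, init=False, repr=False)

    @property
    def sync_module_states(self) -> bool:
        return self._sync_module_states

    @sync_module_states.setter
    def sync_module_states(self, value: bool):
        self._sync_module_states = value

    @classmethod
    def settable_attrs(cls) -> set[str]:
        # Underscore-prefixed (runtime-managed) fields are not user-settable
        return {f.name for f in fields(cls) if not f.name.startswith('_')}


cfg = ExampleConfig()
assert cfg.sync_module_states is False
cfg.sync_module_states = True  # goes through the setter
assert 'verbose' in ExampleConfig.settable_attrs()
assert '_sync_module_states' not in ExampleConfig.settable_attrs()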

tests/trainer/test_fsdp2.py

Lines changed: 12 additions & 34 deletions
@@ -33,7 +33,7 @@ def create_trainer_with_model(
     activation_checkpointing: bool = False,
     activation_cpu_offload: bool = False,
     auto_microbatching: bool = False,
-    sync_module_states: bool = False,
+    fsdp1_sync_module_states: bool = False,
 ) -> Trainer:
     """Helper function to create a Trainer with a model, dataloader, and FSDP2 configuration."""
     dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
@@ -44,12 +44,11 @@ def create_trainer_with_model(
         parallelism_config.fsdp2 = FSDP2Config(
             activation_checkpointing=activation_checkpointing,
             activation_cpu_offload=activation_cpu_offload,
-            sync_module_states=sync_module_states,
         )
     else:
         parallelism_config.fsdp = FSDPConfig(
             state_dict_type='sharded',
-            sync_module_states=sync_module_states,
+            sync_module_states=fsdp1_sync_module_states,
         )
     if optimizer is None:
         optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -346,6 +345,7 @@ def _create_model_with_mixed_init(model_class: type, num_features: int, device:
         resolved_device = device if dist.get_local_rank() == 0 else 'meta'
 
         # set the bias to be True for SimpleComposerMLP
+        # which is used for a later test
         kwargs = {}
         if model_class == SimpleComposerMLP:
             kwargs['add_bias'] = True
@@ -360,14 +360,17 @@ def _create_model_with_mixed_init(model_class: type, num_features: int, device:
 
     @staticmethod
    def _train_model_and_extract_weights(model, use_fsdp2: bool):
-        """Helper function to train a model and extract its weights."""
+        """Helper function to train a mixed init model and extract its weights."""
         optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
+        kwargs = {}
+        if not use_fsdp2:
+            kwargs['fsdp1_sync_module_states'] = True
         trainer = create_trainer_with_model(
             model=model,
             max_duration=f'10ba',
             use_fsdp2=use_fsdp2,
             optimizer=optimizer,
-            sync_module_states=True,
+            **kwargs,
         )
 
         trainer.fit()
@@ -403,31 +406,6 @@ def _compare_weights(fsdp1_weights: dict, fsdp2_weights: dict, tolerance: float
             assert diff < tolerance, \
                 f'Weight difference for {name} exceeds tolerance: {diff} > {tolerance}.'
 
-    @world_size(2)
-    @pytest.mark.gpu
-    # Note that we are testing on a GPU instance just to make sure we can initialize
-    # on CPU and then move to GPU.
-    @pytest.mark.parametrize('device', ['cuda', 'cpu'])
-    def test_fsdp2_sync_module_states_raises_error_when_invalid_mixed_initialization(
-        self,
-        world_size: int,
-        device: str,
-    ):
-        """Test that FSDP2 sync_module_states raises an error when invalid mixed initialization is detected."""
-        del world_size
-        resolved_device = device if dist.get_local_rank() == 0 else 'meta'
-        model = self._create_model_with_mixed_init(SimpleComposerMLP, 10, resolved_device)
-        with pytest.raises(ValueError) as e:
-            create_trainer_with_model(model=model, num_classes=10, use_fsdp2=True, sync_module_states=False)
-        assert 'Detected mixed initialization where some ranks have model on cpu or gpu and some ranks are on meta' in str(
-            e.value,
-        )
-
-        model_2 = self._create_model_with_mixed_init(SimpleComposerMLP, 10, 'meta')
-        with pytest.raises(ValueError) as e:
-            create_trainer_with_model(model=model_2, num_classes=10, use_fsdp2=True, sync_module_states=True)
-        assert 'Detected that all ranks (including rank 0) have model on meta' in str(e.value)
-
     @world_size(2)
     @pytest.mark.gpu
     # Note that we are testing on a GPU instance just to make sure we can initialize
@@ -458,8 +436,8 @@ def param_init_fn(module: torch.nn.Module):
             model=model,
             num_classes=10,
             use_fsdp2=True,
-            sync_module_states=True,
         )
+        assert trainer.state.fsdp_config.sync_module_states, 'sync_module_states should be True'  # type: ignore
 
         module = trainer.state.model.module  # type: ignore
         assert torch.equal(
@@ -500,9 +478,8 @@ def test_fsdp2_mixed_init_does_not_break_weight_tying(
             model=model,
             num_classes=10,
             use_fsdp2=True,
-            sync_module_states=True,
         )
-
+        assert trainer.state.fsdp_config.sync_module_states, 'sync_module_states should be True'  # type: ignore
         # Check that the weights are correctly tied after training
         trainer.fit()
         weight_1 = model.mlp.fc1.weight.full_tensor()  # type: ignore
@@ -532,8 +509,9 @@ def test_fsdp2_sync_module_state_aligns_with_optimizer_state(
             num_classes=10,
             use_fsdp2=True,
             optimizer=optimizer,
-            sync_module_states=True,
         )
+
+        assert trainer.state.fsdp_config.sync_module_states, 'sync_module_states should be True'  # type: ignore
         trainer.fit()
 
         assert torch.equal(

tests/trainer/test_fsdp2_config.py

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,7 @@ def test_fsdp2config_from_empty_attributes():
     assert fsdp2_config.activation_cpu_offload is False  # default value
     assert fsdp2_config.reshard_after_forward is True  # default value
     assert fsdp2_config.device_mesh is None  # default value
+    assert fsdp2_config.sync_module_states is False  # default value
 
 
 def test_fsdp2config_from_fsdp1_multiple_invalid_attributes():
@@ -76,6 +77,7 @@ def test_fsdp2config_from_fsdp1_multiple_invalid_attributes():
         'invalid_attribute1': 'value1',
         'invalid_attribute2': 'value2',
         'auto_wrap': True,
+        'sync_module_states': True,
     }
 
     with pytest.warns() as record:
@@ -87,3 +89,4 @@
     assert any('invalid_attribute1: value1' in msg for msg in warning_messages)
     assert any('invalid_attribute2: value2' in msg for msg in warning_messages)
     assert any('auto_wrap: True' in msg for msg in warning_messages)
+    assert any('sync_module_states: True' in msg for msg in warning_messages)
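
These tests pin down the user-facing behavior: sync_module_states defaults to False on FSDP2Config, and passing it through from_compatible_attrs is treated like any other unsupported FSDP1 attribute. A minimal usage sketch mirroring the tests (assuming FSDP2Config is importable from composer.utils.parallelism, matching the file path above):

import warnings

from composer.utils.parallelism import FSDP2Config

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # The retired FSDP1 attr is dropped with a warning instead of being applied
    cfg = FSDP2Config.from_compatible_attrs({'sync_module_states': True})

assert any('sync_module_states: True' in str(w.message) for w in caught)
assert cfg.sync_module_states is False  # default; the trainer enables it when mixed init is detected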
