Commit 0272447

fixed small bugs
1 parent d5beb0d commit 0272447

5 files changed: +23 -21 lines changed


composer/distributed/dist_strategy.py

Lines changed: 2 additions & 14 deletions
@@ -31,7 +31,7 @@
     get_mixed_precision,
     set_custom_fsdp_module_kwargs,
 )
-from composer.distributed.shared_utils import add_fsdp_oom_hooks
+from composer.distributed.shared_utils import add_fsdp_oom_hooks, validate_model_requires_state_sync
 from composer.utils import FSDPConfig, StringEnum, TPConfig, dist, ensure_tuple, get_device

 __all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module', 'prepare_tp_module']
@@ -246,19 +246,7 @@ def prepare_fsdp_module(
     _validate_precision(precision, device)

     # Check sync_module_states is True for mixed initialization or HSDP
-    if fsdp_config.sync_module_states == False:
-        rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0
-        all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
-        dist.all_reduce(all_ranks_meta, reduce_operation='MIN')
-        any_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
-        dist.all_reduce(any_ranks_meta, reduce_operation='MAX')
-        if all_ranks_meta.item() == 0 and any_ranks_meta.item() == 1:
-            raise ValueError(
-                'Detected mixed initialization where some ranks have model on cpu or '
-                'gpu and some ranks are on meta. Either keep all ranks on the same '
-                "device or set parallelism_config['fsdp']['sync_module_states'] = True. Otherwise, "
-                'some weights may be randomly initialized when loading a checkpoint.',
-            )
+    validate_model_requires_state_sync(model, fsdp_config)

     # Handles of FSDP sync hooks if automicrobatching is on
     hook_handles = []

composer/distributed/prepare_distributed.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from composer.distributed.fsdp2 import prepare_fully_shard, sync_module_states
 from composer.distributed.fsdp2_utils import generate_composer_model_policy, sync_optimizer_and_model_params
 from composer.distributed.param_init import meta_init
-from composer.distributed.shared_utils import update_model_requires_state_sync
+from composer.distributed.shared_utils import validate_model_requires_state_sync
 from composer.models import ComposerModel
 from composer.utils import dist
 from composer.utils.parallelism import FSDP2Config
@@ -69,7 +69,7 @@ def _parallelize_model_helper(
     2. With sync_module_states: param_init on rank 0 first, then fully_shard, then broadcast the
        initialized state to all other ranks. This makes sure that all ranks have rank 0's model state.
     """
-    update_model_requires_state_sync(model, config)
+    validate_model_requires_state_sync(model, config)

     if config.sync_module_states:
         # If we are syncing module states, we assume that rank 0 has the model on CPU/GPU

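For context, the sync_module_states path described in the docstring above expects only rank 0 to hold real weights while every other rank builds the module on the meta device. Below is a minimal sketch of that setup, following the device-selection pattern from the test change further down; the helper name, the generic model class, and the hard-coded 'cuda' device are illustrative and not part of this commit.

import torch
from composer.utils import dist

def build_model_for_state_sync(model_cls: type, **model_kwargs) -> torch.nn.Module:
    # Rank 0 materializes real weights on CPU/GPU ('cuda' stands in for whatever
    # accelerator the trainer targets); all other ranks construct the module on
    # the meta device, so only shapes/dtypes exist there until rank 0's state is
    # broadcast after sharding (sync_module_states=True).
    device = 'cuda' if dist.get_local_rank() == 0 else 'meta'
    return model_cls(device=device, **model_kwargs)

Any other layout, such as a non-zero rank ending up on CPU/GPU or rank 0 ending up on meta, is exactly what validate_model_requires_state_sync is meant to reject.
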
composer/distributed/shared_utils.py

Lines changed: 11 additions & 3 deletions
@@ -159,7 +159,7 @@ def add_fsdp_oom_hooks(model: torch.nn.Module, device: Optional[Device] = None)
     return hook_handles


-def update_model_requires_state_sync(model: nn.Module, fsdp_config: FSDP2Config | FSDPConfig) -> None:
+def validate_model_requires_state_sync(model: nn.Module, fsdp_config: FSDP2Config | FSDPConfig) -> None:
     """Checks if sync_module_states configuration is compatible with model initialization.

     When sync_module_states is False, this function checks that all ranks have their model
@@ -189,12 +189,20 @@ def update_model_requires_state_sync(model: nn.Module, fsdp_config: FSDP2Config
         raise ValueError(
             'Detected mixed initialization where some ranks have model on cpu or '
             'gpu and some ranks are on meta. Please set '
-            'parallelism_config["fsdp"]["sync_module_states"] = True.',
+            'parallelism_config["fsdp"]["sync_module_states"] = True. '
+            'Make sure that rank 0 is the only rank where the model is NOT on meta.',
         )

+    # Note: We do this instead of raising an error if
+    # (rank=0 && device=meta) || (rank!=0 && device!=meta)
+    # because that only raises an error on disparate ranks
+    # and that crashes the program, we instead just assert
+    # the setup later and note in the earlier error if the setup is invalid
     if fsdp_config.sync_module_states and all_ranks_meta.item() == 1:
         raise ValueError(
             'Detected that all ranks (including rank 0) have model on meta. '
             'Will not sync module states. Please set '
-            'parallelism_config["fsdp"]["sync_module_states"] = False.',
+            'parallelism_config["fsdp"]["sync_module_states"] = False. '
+            'If you want to sync module states, make sure that rank 0 '
+            'is the only rank where the model is NOT on meta.',
         )

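The hunk above only shows the renamed signature, the updated error messages, and the new note; the detection logic itself is unchanged context that the diff elides. Stitching it together with the inline check removed from dist_strategy.py gives roughly the sketch below. How the helper obtains the device and exactly which condition guards the first error are assumptions, not shown in this commit.

from __future__ import annotations

import torch
from torch import nn

from composer.utils import FSDPConfig, dist, get_device  # get_device(None) picking the current accelerator is an assumption
from composer.utils.parallelism import FSDP2Config


def validate_model_requires_state_sync(model: nn.Module, fsdp_config: FSDP2Config | FSDPConfig) -> None:
    # Reconstruction for reference only; see the lead-in for what is assumed.
    device = get_device(None)

    # Each rank reports whether its copy of the model lives on the meta device.
    rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0
    all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
    dist.all_reduce(all_ranks_meta, reduce_operation='MIN')  # 1 only if every rank is on meta
    any_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
    dist.all_reduce(any_ranks_meta, reduce_operation='MAX')  # 1 if any rank is on meta

    # Mixed initialization (some ranks on cpu/gpu, some on meta) without state sync.
    # Guarding this on sync_module_states being False is assumed from the old inline check.
    if not fsdp_config.sync_module_states and all_ranks_meta.item() == 0 and any_ranks_meta.item() == 1:
        raise ValueError(
            'Detected mixed initialization where some ranks have model on cpu or '
            'gpu and some ranks are on meta. Please set '
            'parallelism_config["fsdp"]["sync_module_states"] = True. '
            'Make sure that rank 0 is the only rank where the model is NOT on meta.',
        )

    # State sync requested but there is no source of weights: every rank, including rank 0, is on meta.
    if fsdp_config.sync_module_states and all_ranks_meta.item() == 1:
        raise ValueError(
            'Detected that all ranks (including rank 0) have model on meta. '
            'Will not sync module states. Please set '
            'parallelism_config["fsdp"]["sync_module_states"] = False. '
            'If you want to sync module states, make sure that rank 0 '
            'is the only rank where the model is NOT on meta.',
        )

The MIN/MAX pair of all-reduces lets every rank reach the same verdict collectively, so the helper can raise the same error on all ranks rather than only on the ranks whose device differs, which is the point of the note added in this commit.
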
tests/common/models.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def __init__(
         num_features: int,
         device: Union[str, torch.device],
         num_classes: int = 3,
-        add_bias: bool = True,
+        add_bias: bool = False,
     ):
         fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=add_bias)
         fc2 = torch.nn.Linear(num_features, num_classes, device=device, bias=add_bias)

tests/trainer/test_fsdp2.py

Lines changed: 7 additions & 1 deletion
@@ -344,7 +344,13 @@ def _create_model_with_mixed_init(model_class: type, num_features: int, device:
         TestFSDP2MixedInit._set_deterministic_seed(seed)

         resolved_device = device if dist.get_local_rank() == 0 else 'meta'
-        model = model_class(num_features=num_features, device=resolved_device)
+
+        # set the bias to be True for SimpleComposerMLP
+        kwargs = {}
+        if model_class == SimpleComposerMLP:
+            kwargs['add_bias'] = True
+
+        model = model_class(num_features=num_features, device=resolved_device, **kwargs)
         model.add_fsdp_wrap_attribute_to_children()

         if dist.get_local_rank() == 0:
