@@ -31,7 +31,7 @@
     get_mixed_precision,
     set_custom_fsdp_module_kwargs,
 )
-from composer.distributed.shared_utils import add_fsdp_oom_hooks
+from composer.distributed.shared_utils import add_fsdp_oom_hooks, validate_model_requires_state_sync
 from composer.utils import FSDPConfig, StringEnum, TPConfig, dist, ensure_tuple, get_device
 
 __all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module', 'prepare_tp_module']
@@ -246,19 +246,7 @@ def prepare_fsdp_module(
     _validate_precision(precision, device)
 
     # Check sync_module_states is True for mixed initialization or HSDP
-    if fsdp_config.sync_module_states == False:
-        rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0
-        all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
-        dist.all_reduce(all_ranks_meta, reduce_operation='MIN')
-        any_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
-        dist.all_reduce(any_ranks_meta, reduce_operation='MAX')
-        if all_ranks_meta.item() == 0 and any_ranks_meta.item() == 1:
-            raise ValueError(
-                'Detected mixed initialization where some ranks have model on cpu or '
-                'gpu and some ranks are on meta. Either keep all ranks on the same '
-                "device or set parallelism_config['fsdp']['sync_module_states'] = True. Otherwise, "
-                'some weights may be randomly initialized when loading a checkpoint.',
-            )
+    validate_model_requires_state_sync(model, fsdp_config)
 
     # Handles of FSDP sync hooks if automicrobatching is on
     hook_handles = []
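
For context on what the new helper encapsulates, here is a minimal sketch of validate_model_requires_state_sync, reconstructed from the inline check that the second hunk removes. The function body, the FSDPConfig type hint, the docstring, and the get_device(None) call are assumptions for illustration only; the actual helper in composer.distributed.shared_utils may resolve the device and structure the check differently.

# Minimal sketch, not the actual shared_utils implementation: reconstructed from
# the inline check removed in this diff; details below are assumptions.
import torch

from composer.utils import FSDPConfig, dist, get_device


def validate_model_requires_state_sync(model: torch.nn.Module, fsdp_config: FSDPConfig) -> None:
    """Raise if ranks mix meta and cpu/gpu initialization while sync_module_states is off."""
    if fsdp_config.sync_module_states:
        # Syncing module states from rank 0 makes mixed initialization safe, so skip the check.
        return

    device = get_device(None)  # assumption: the helper resolves the device itself
    rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0

    # After a MIN all-reduce, a 0 means at least one rank is materialized on cpu/gpu;
    # after a MAX all-reduce, a 1 means at least one rank is still on the meta device.
    all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
    dist.all_reduce(all_ranks_meta, reduce_operation='MIN')
    any_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
    dist.all_reduce(any_ranks_meta, reduce_operation='MAX')

    if all_ranks_meta.item() == 0 and any_ranks_meta.item() == 1:
        raise ValueError(
            'Detected mixed initialization where some ranks have the model on cpu or gpu '
            'and some ranks are on meta. Either keep all ranks on the same device or set '
            "parallelism_config['fsdp']['sync_module_states'] = True. Otherwise, "
            'some weights may be randomly initialized when loading a checkpoint.',
        )

With a helper shaped like this, the call site in prepare_fsdp_module collapses to the single validate_model_requires_state_sync(model, fsdp_config) line shown in the second hunk, and factoring the check into shared_utils presumably lets other distributed entry points reuse it.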