10 changes: 10 additions & 0 deletions composer/distributed/fsdp2.py
@@ -3,6 +3,7 @@

"""Helpers for FSDP2."""

import logging
from typing import Optional

import torch
@@ -18,6 +19,8 @@
)
from composer.utils.parallelism import FSDP2Config

log = logging.getLogger(__name__)


def _recursive_apply_fully_shard(
root_module: nn.Module,
@@ -130,3 +133,10 @@ def prepare_fully_shard(
# Check for parameter tying
with check_param_tying(model):
apply_fully_shard(model, fsdp2_config, auto_wrap_policy)

if fsdp2_config.verbose:
log.info(f'FSDP2: Fully sharded model:\n{model}')
for attr in fsdp2_config.settable_attrs():
if attr == 'verbose':
continue
log.info(f'FSDP2: {attr}: {getattr(fsdp2_config, attr)}')
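For context, a minimal usage sketch of the new verbose logging. The keyword-argument construction of FSDP2Config and the exact log lines are assumptions based on this diff, not verified against the full codebase:

```python
# Minimal sketch, assuming FSDP2Config accepts its settable attrs as keyword
# arguments (an assumption for illustration).
import logging

from composer.distributed.fsdp2 import prepare_fully_shard
from composer.utils.parallelism import FSDP2Config

logging.basicConfig(level=logging.INFO)

fsdp2_config = FSDP2Config(verbose=True)
# prepare_fully_shard(model, fsdp2_config, auto_wrap_policy)  # model and policy omitted here
#
# With verbose=True, the log should contain the fully sharded module tree,
# followed by one line per settable attribute (except 'verbose'), e.g.:
#   FSDP2: activation_checkpointing: False
#   FSDP2: activation_cpu_offload: False
```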
24 changes: 21 additions & 3 deletions composer/distributed/prepare_distributed.py
@@ -3,7 +3,9 @@

"""Entrypoint for distributed training (using FSDP2)."""

from contextlib import nullcontext
import logging
import time
from contextlib import contextmanager, nullcontext
from typing import Callable, Optional

import torch
@@ -16,6 +18,20 @@
from composer.models import ComposerModel
from composer.utils.parallelism import FSDP2Config

log = logging.getLogger(__name__)


# TODO put this func into a general util function file
@contextmanager
def log_execution_time(logger: logging.Logger, operation_name: str):
"""Log the execution time of a block of code."""
start_time = time.time()
try:
yield
finally:
end_time = time.time()
logger.info(f'{operation_name} took {end_time - start_time:.2f} seconds')
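A quick usage sketch of the new helper; the logger name and the timed body below are placeholders, not part of this PR:

```python
# Illustrative use of log_execution_time as added in this file.
import logging
import time

from composer.distributed.prepare_distributed import log_execution_time

logging.basicConfig(level=logging.INFO)
example_logger = logging.getLogger('example')

with log_execution_time(example_logger, 'Sleep demo'):
    time.sleep(0.5)
# Logs approximately: "Sleep demo took 0.50 seconds"
```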


def parallelize_model(
model: torch.nn.Module,
@@ -64,8 +80,10 @@ def parallelize_model(

# Use the context manager for optimizer synchronization if optimizer is provided
with sync_optimizer_and_model_params(optimizer, model) if optimizer is not None else nullcontext():
prepare_fully_shard(model, config, fsdp_wrap_policy)
param_init_fn(model)
with log_execution_time(log, 'Prepare FSDP2'):
prepare_fully_shard(model, config, fsdp_wrap_policy)
with log_execution_time(log, 'Meta Init Device'):
param_init_fn(model)
# NOTE apply_ac cannot be included in this context, as it would wrap and replace the sub-modules and thus invalidate the FQNs of the params


2 changes: 1 addition & 1 deletion composer/utils/parallelism.py
@@ -81,7 +81,7 @@ class FSDP2Config:
activation_checkpointing: bool = False
activation_cpu_offload: bool = False

# TODO: add support of versose
verbose: bool = False

@classmethod
def settable_attrs(cls) -> set[str]:
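For reference, a hedged sketch of how the new verbose field relates to settable_attrs(). It assumes settable_attrs() returns the names of the dataclass fields (including 'verbose'), which this diff implies but does not show in full:

```python
# Sketch of the enumeration pattern used by the verbose logging in fsdp2.py;
# the field values passed to the constructor here are illustrative only.
from composer.utils.parallelism import FSDP2Config

cfg = FSDP2Config(activation_checkpointing=True, verbose=True)
for attr in sorted(cfg.settable_attrs()):
    print(f'{attr}: {getattr(cfg, attr)}')
```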