Commit 570fd2e

Activation Checkpointing and Offloading for FSDP2 (#3832)
1 parent 3c29f1a commit 570fd2e

8 files changed: +504 -10 lines changed

composer/distributed/activation_checkpointing.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Helpers for activation checkpointing. Note that while this is orthogonal to FSDP2, it is implemented in the distributed directory because it is closely related to FSDP2."""

from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
    offload_wrapper,
)


def generate_default_check_fn(model: nn.Module) -> Callable:
    """Generates the default check fn for activation checkpointing/offloading."""

    def _check_fn(module: torch.nn.Module) -> bool:
        if hasattr(module, '_activation_checkpointing'):
            return bool(module._activation_checkpointing)
        if hasattr(
            model,
            'activation_checkpointing_fn',
        ) and isinstance(model.activation_checkpointing_fn, Callable):
            return model.activation_checkpointing_fn(module)
        return False

    return _check_fn


def apply_ac(
    model: nn.Module,
    activation_checkpointing: bool,
    activation_cpu_offload: bool,
    check_fn: Optional[Callable] = None,
) -> None:
    """Apply activation checkpointing to the model. This is orthogonal to FSDP2 so it can be applied pre-sharding or post-sharding.

    This method follows the same logic as FSDP1 as well as TorchTitan's AC example.

    Args:
        model (nn.Module): The model to apply activation checkpointing to.
        activation_checkpointing (bool): Whether to apply activation checkpointing.
        activation_cpu_offload (bool): Whether to offload activations to the CPU.
        check_fn (Optional[Callable]): An optional function to determine if a module should be checkpointed.
    """
    # Create the base checkpointing wrapper using no_reentrant checkpointing by default as
    # PyTorch notes that reentrant checkpointing is deprecated and will be removed in a future release
    opt_checkpoint_wrapper = lambda m: checkpoint_wrapper(
        m,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    ) if activation_checkpointing else (lambda module: module)
    # Create the combined wrapper which takes cpu offloading into consideration
    opt_combined_wrapper = (
        lambda module: offload_wrapper(
            opt_checkpoint_wrapper(module)
            if activation_checkpointing else module,  # type: ignore reportGeneralTypeIssues
        )
    ) if activation_cpu_offload else opt_checkpoint_wrapper

    # Create the check function to determine if a module should be checkpointed
    if check_fn is None:
        check_fn = generate_default_check_fn(model)

    # Apply the activation checkpointing on the model, this uses _recursive_wrap to apply the wrapper to all submodules
    # but doesn't apply the wrapper to the root module
    apply_activation_checkpointing(model, opt_combined_wrapper, check_fn)  # type: ignore
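
To make the opt-in mechanics above concrete, here is a minimal usage sketch, not taken from the commit, assuming this file lands as composer/distributed/activation_checkpointing.py and a PyTorch build that ships the _checkpoint wrapper utilities. A submodule is flagged with the _activation_checkpointing attribute that generate_default_check_fn looks for, and apply_ac then wraps only that submodule; the toy model is purely illustrative:

# Hypothetical usage sketch; the toy model and flag placement are illustrative only.
import torch
import torch.nn as nn

from composer.distributed.activation_checkpointing import apply_ac

model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16))
# The default check fn returns True for any module carrying this attribute.
model[0]._activation_checkpointing = True  # type: ignore[attr-defined]

# Wrap the flagged submodule with non-reentrant checkpointing; no CPU offload.
apply_ac(model, activation_checkpointing=True, activation_cpu_offload=False)

loss = model(torch.randn(4, 16)).sum()
loss.backward()  # the flagged layer's forward is re-run here instead of caching its activations
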
composer/distributed/prepare_distributed.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Entrypoint for distributed training (using FSDP2)."""

from typing import Callable, Optional

import torch
from torch.distributed.fsdp.wrap import CustomPolicy

from composer.distributed.activation_checkpointing import apply_ac
from composer.distributed.fsdp2 import prepare_fully_shard
from composer.utils.parallelism import FSDP2Config, FSDPConfig


def parallelize_model(
    model: torch.nn.Module,
    config: FSDP2Config | FSDPConfig,
    optimizer: Optional[torch.optim.Optimizer] = None,
    fsdp_wrap_policy: Optional[CustomPolicy] = None,
    activation_checkpointing_check_fn: Optional[Callable] = None,
):
    """Prepare a model for distributed training.

    Args:
        model (torch.nn.Module): The model to prepare for distributed training.
        config (FSDP2Config | FSDPConfig): The configuration for distributed training. Currently only FSDP2Config is supported.
        optimizer (Optional[torch.optim.Optimizer]): The optimizer to use for distributed training.
        fsdp_wrap_policy (Optional[CustomPolicy]): The FSDP wrap policy to use for distributed training.
        activation_checkpointing_check_fn (Optional[Callable]): The function to use to check if a module's activations should be checkpointed or offloaded.
    """
    if isinstance(config, FSDPConfig):
        raise ValueError('FSDPConfig is not supported for now, use FSDP2Config instead')

    if activation_checkpointing_check_fn is not None:
        if not config.activation_checkpointing and not config.activation_cpu_offload:
            raise ValueError(
                'Activation checkpointing or offloading must be enabled if activation_checkpointing_check_fn is provided',
            )

    if config.activation_checkpointing or config.activation_cpu_offload:
        apply_ac(
            model,
            config.activation_checkpointing,
            config.activation_cpu_offload,
            activation_checkpointing_check_fn,
        )

    prepare_fully_shard(model, optimizer, config, fsdp_wrap_policy)
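
As a rough end-to-end sketch of the new entrypoint (illustrative, not the commit's own example; it assumes torch >= 2.6, an initialized process group such as one launched with torchrun, and that this file lands as composer/distributed/prepare_distributed.py, as the test imports further below suggest), activation checkpointing is driven entirely by the two new FSDP2Config flags and is applied before the model is fully sharded:

# Sketch only: the toy model, dimensions, and optimizer choice are assumptions.
import torch

from composer.distributed.prepare_distributed import parallelize_model
from composer.utils.parallelism import FSDP2Config

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32, device='cuda'),
    torch.nn.Linear(32, 16, device='cuda'),
)
model[0]._activation_checkpointing = True  # opt the first layer in to checkpointing
optimizer = torch.optim.AdamW(model.parameters())

config = FSDP2Config(
    activation_checkpointing=True,  # wrap flagged modules with checkpoint_wrapper
    activation_cpu_offload=False,   # keep checkpointed activations on the GPU
)

# apply_ac runs first (pre-sharding), then prepare_fully_shard shards the model with FSDP2.
parallelize_model(model, config, optimizer=optimizer)
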

composer/utils/parallelism.py

Lines changed: 4 additions & 4 deletions
@@ -76,6 +76,10 @@ class FSDP2Config:
     # Settable core FSDP2 attrs
     device_mesh: Optional[DeviceMesh] = None
     reshard_after_forward: bool | int = True
+    # TODO: If we have reasonable evidence that activation checkpointing/activation offloading is decoupled from FSDP(2)
+    # in most of our use cases, we can decouple these two attributes from the FSDP2Config class.
+    activation_checkpointing: bool = False
+    activation_cpu_offload: bool = False

     ### Temporary read-only properties for FSDP 1 compatibility ###
     # to be supported in FSDP2
@@ -103,10 +107,6 @@ def save_planner(self) -> Optional[Any]:
     def sharded_ckpt_prefix_dir(self) -> str:
         return 'ep{epoch}-ba{batch}'

-    @property
-    def activation_cpu_offload(self) -> bool:
-        return False
-
     @property
     def data_parallel_shard_degree(self) -> int:
         return -1

tests/common/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -18,7 +18,9 @@
 from tests.common.events import EventCounterCallback
 from tests.common.markers import device, world_size
 from tests.common.models import (
+    ComposerCounterModel,
     ConvModel,
+    CountModule,
     EmbeddedWeightTiedModel,
     EmptyModel,
     EvenSimplerMLP,
@@ -75,4 +77,6 @@ def get_module_subclasses(module: types.ModuleType, cls: type) -> list[type]:
     'EvenSimplerMLP',
     'SimpleComposerMLP',
     'TPSimpleComposerMLP',
+    'ComposerCounterModel',
+    'CountModule',
 ]

tests/common/models.py

Lines changed: 34 additions & 0 deletions
@@ -146,6 +146,40 @@ def add_fsdp_wrap_attribute_to_children(self):
             child._fsdp_wrap = True  # type: ignore


+class CountModule(torch.nn.Module):
+
+    def __init__(self, num_inputs: int, num_outputs: int, device: Union[str, torch.device]):
+        super().__init__()
+        self.call_count = 0
+        self.inner_1 = torch.nn.Linear(num_inputs, num_outputs, device=device, bias=False)
+        self.inner_2 = torch.nn.Linear(num_outputs, num_outputs, device=device, bias=False)
+
+    def forward(self, x):
+        self.call_count += 1
+        x = self.inner_1(x)
+        x = self.inner_2(x)
+        return x
+
+
+# A simple MLP with two hidden layers where the module counts the number of times it calls forward
+# This is used to test activation checkpointing
+class ComposerCounterModel(ComposerClassifier):
+
+    def __init__(
+        self,
+        num_inputs: int,
+        num_outputs: int,
+        device: Union[str, torch.device],
+        num_hidden_layer_features: int = 8,
+    ):
+        module = torch.nn.Sequential(
+            CountModule(num_inputs, num_hidden_layer_features, device),
+            CountModule(num_hidden_layer_features, num_outputs, device),
+        )
+        super().__init__(num_classes=num_outputs, module=module)
+        self.module = module
+
+
 # Like SimpleComposerMLP but saves each layer which is necessary to TP to it.
 class TPSimpleComposerMLP(ComposerClassifier):
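
The counter model exists because non-reentrant checkpointing re-runs a wrapped module's forward during the backward pass, so a checkpointed CountModule should see its call_count incremented twice per step: once in the forward pass and once during recomputation. Below is a hedged sketch of how that behavior could be observed, using the helpers from this commit rather than its actual test code, and assuming the tests package is importable:

# Illustrative check, not the commit's test; module paths follow this diff.
import torch

from composer.distributed.activation_checkpointing import apply_ac
from tests.common.models import ComposerCounterModel

model = ComposerCounterModel(num_inputs=8, num_outputs=2, device='cpu')
blocks = list(model.module)  # keep handles to the original CountModules before wrapping

# Flag both CountModules via the attribute the default check fn reads, then wrap them.
for block in blocks:
    block._activation_checkpointing = True  # type: ignore[attr-defined]
apply_ac(model, activation_checkpointing=True, activation_cpu_offload=False)

loss = model.module(torch.randn(4, 8)).sum()
loss.backward()

# Each block ran once in the forward pass and once more during backward recomputation.
print([block.call_count for block in blocks])  # expected: [2, 2]
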
tests/trainer/fsdp2_context.py

Lines changed: 5 additions & 1 deletion
@@ -10,14 +10,18 @@
 SKIP_TEST = version.parse(torch.__version__) < version.parse('2.6.0')
 if not SKIP_TEST:
     # TODO (FSDP2) move this to top once we deprecate torch 2.5
-    from composer.distributed import fsdp2
+    from composer.distributed import activation_checkpointing, fsdp2, prepare_distributed
+    apply_ac = activation_checkpointing.apply_ac
+    parallelize_model = prepare_distributed.parallelize_model
     prepare_fully_shard = fsdp2.prepare_fully_shard
     legalize_param_sharing_between_modules = fsdp2.legalize_param_sharing_between_modules
     get_standalone_and_tied_modules = fsdp2.get_standalone_and_tied_modules
     _recursive_apply_fully_shard = fsdp2._recursive_apply_fully_shard
     _generate_default_policy = fsdp2.generate_default_policy
     check_param_tying = fsdp2.check_param_tying
 else:
+    apply_ac = lambda *args, **kwargs: None
+    parallelize_model = lambda *args, **kwargs: None
     prepare_fully_shard = lambda *args, **kwargs: None
     legalize_param_sharing_between_modules = lambda *args, **kwargs: None
     get_standalone_and_tied_modules = lambda *args, **kwargs: ([], set())
