# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Helpers for activation checkpointing.

Note that while this is orthogonal to FSDP2, it is implemented in the distributed directory
because it is closely related to FSDP2.
"""

from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
    offload_wrapper,
)


def generate_default_check_fn(model: nn.Module) -> Callable:
    """Generates the default check fn for activation checkpointing/offloading."""

    def _check_fn(module: torch.nn.Module) -> bool:
        if hasattr(module, '_activation_checkpointing'):
            return bool(module._activation_checkpointing)
        if hasattr(
            model,
            'activation_checkpointing_fn',
        ) and isinstance(model.activation_checkpointing_fn, Callable):
            return model.activation_checkpointing_fn(module)
        return False

    return _check_fn

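
# Assumed usage, for illustration only (nothing below is defined in this file): a submodule can
# opt in to checkpointing by setting an attribute on itself, e.g. in its ``__init__``:
#
#     self._activation_checkpointing = True
#
# Alternatively, the root model can expose a predicate that the default check fn will call, e.g.:
#
#     model.activation_checkpointing_fn = lambda module: isinstance(module, MyTransformerBlock)
#
# where ``MyTransformerBlock`` is a hypothetical block class.
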
def apply_ac(
    model: nn.Module,
    activation_checkpointing: bool,
    activation_cpu_offload: bool,
    check_fn: Optional[Callable] = None,
) -> None:
    """Apply activation checkpointing to the model.

    Activation checkpointing is orthogonal to FSDP2, so it can be applied pre-sharding or post-sharding.
    This method follows the same logic as FSDP1 as well as TorchTitan's AC example.

    Args:
        model (nn.Module): The model to apply activation checkpointing to.
        activation_checkpointing (bool): Whether to apply activation checkpointing.
        activation_cpu_offload (bool): Whether to offload activations to the CPU.
        check_fn (Optional[Callable]): An optional function to determine if a module should be checkpointed.
    """
    # Create the base checkpointing wrapper using non-reentrant checkpointing by default, as
    # PyTorch notes that reentrant checkpointing is deprecated and will be removed in a future release
    opt_checkpoint_wrapper = (
        lambda m: checkpoint_wrapper(
            m,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
    ) if activation_checkpointing else (lambda module: module)
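    # Note: `opt_checkpoint_wrapper` is a module -> module callable; it is only invoked on the
    # submodules selected by `check_fn` below, never on the root model.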
    # Create the combined wrapper which takes cpu offloading into consideration
    opt_combined_wrapper = (
        lambda module: offload_wrapper(
            opt_checkpoint_wrapper(module)
            if activation_checkpointing else module,  # type: ignore reportGeneralTypeIssues
        )
    ) if activation_cpu_offload else opt_checkpoint_wrapper

    # Create the check function to determine if a module should be checkpointed
    if check_fn is None:
        check_fn = generate_default_check_fn(model)

    # Apply the activation checkpointing on the model. This uses _recursive_wrap to apply the wrapper
    # to all submodules but doesn't apply the wrapper to the root module.
    apply_activation_checkpointing(model, opt_combined_wrapper, check_fn)  # type: ignore
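

# Minimal usage sketch, assuming a toy model whose blocks opt in via `_activation_checkpointing`
# (the classes and names below are illustrative, not Composer APIs). Each tagged block is wrapped
# with a non-reentrant checkpoint wrapper, so its activations are recomputed during backward.
if __name__ == '__main__':

    class _Block(nn.Module):

        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(8, 8)
            self._activation_checkpointing = True  # picked up by the default check fn

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class _ToyModel(nn.Module):

        def __init__(self) -> None:
            super().__init__()
            self.blocks = nn.Sequential(_Block(), _Block())
            self.head = nn.Linear(8, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.head(self.blocks(x))

    toy_model = _ToyModel()
    apply_ac(toy_model, activation_checkpointing=True, activation_cpu_offload=False)
    # The tagged blocks are now checkpoint-wrapped; forward and backward behave as before.
    toy_model(torch.randn(4, 8)).sum().backward()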