
Commit c565646

buptzyb, pre-commit-ci[bot], and timmoon10 authored and committed
[PyTorch] Support bf16+fp8 cudagraph (NVIDIA#2098)
* support bf16+fp8 model

Signed-off-by: Robin Zhang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* update

Signed-off-by: Robin Zhang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* update

Signed-off-by: Robin Zhang <[email protected]>

---------

Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
1 parent fce4e98 commit c565646

File tree

1 file changed: +20 -11 lines changed

transformer_engine/pytorch/graph.py

Lines changed: 20 additions & 11 deletions
@@ -850,7 +850,7 @@ def make_graphed_callables(
     num_warmup_iters: int = 3,
     allow_unused_input: bool = False,
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
-    fp8_enabled: bool = False,
+    fp8_enabled: SingleOrTuple[bool] = False,
     fp8_calibrating: bool = False,
     fp8_recipe: Optional[Recipe] = None,
     fp8_group: Optional[dist_group_type] = None,
@@ -896,8 +896,9 @@ def make_graphed_callables(
 
     FP8-related parameters
     ----------------------
-    fp8_enabled: bool, default = `True`
-                 whether or not to enable fp8
+    fp8_enabled: (tuple of) bool, default = `False`
+                 whether or not to enable fp8.
+                 If tuple, the length must match the number of modules.
     fp8_calibrating: bool, default = `False`
                      calibration mode allows collecting statistics such as amax and scale
                      data of fp8 tensors even when executing without fp8 enabled. This is
@@ -919,17 +920,25 @@ def make_graphed_callables(
     """
     set_capture_start()
 
-    if fp8_enabled and fp8_recipe is None:
-        fp8_recipe = get_default_fp8_recipe()
-    elif not fp8_enabled:
-        fp8_recipe = None
-
     # Handle single module.
     just_one_callable = False
     if not isinstance(modules, tuple):
         just_one_callable = True
         modules = (modules,)
 
+    if not isinstance(fp8_enabled, tuple):
+        assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools"
+        fp8_enabled = (fp8_enabled,) * len(modules)
+    else:
+        assert len(fp8_enabled) == len(
+            modules
+        ), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})"
+    if any(fp8_enabled) and fp8_recipe is None:
+        fp8_recipe = get_default_fp8_recipe()
+    elif not any(fp8_enabled):
+        fp8_recipe = None
+    module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled))
+
     # Store FP8 tensors to reset later.
     saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
 
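For context, a minimal standalone sketch of the normalization this hunk introduces (the stand-in modules below are illustrative, not from the commit): a single bool is broadcast to every module, a tuple is validated against the module count, and each module id is mapped to its own flag.

import torch

# Stand-in modules; in practice these are the callables passed to
# make_graphed_callables (e.g. Transformer Engine layers).
modules = (torch.nn.Linear(16, 16), torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))

# Per-module flags: FP8 for the first and third module, bf16 for the second.
fp8_enabled = (True, False, True)
if not isinstance(fp8_enabled, tuple):
    fp8_enabled = (fp8_enabled,) * len(modules)  # broadcast a single bool
assert len(fp8_enabled) == len(modules)

# Same mapping the patch builds: module id -> FP8 flag, consumed by call_func.
module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled))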
@@ -944,15 +953,15 @@ def wrap_autocast(block):
         old_call_funcs[block_cls] = block_cls.__call__
 
         # Wrap the original call function of the module class.
-        def call_func(*args, **kwargs):
+        def call_func(self, *args, **kwargs):
             with fp8_autocast(
-                enabled=fp8_enabled,
+                enabled=module_uses_fp8.get(id(self), False),
                 calibrating=fp8_calibrating,
                 fp8_recipe=fp8_recipe,
                 fp8_group=fp8_group,
                 _graph=True,
             ):
-                outputs = old_call_funcs[block_cls](*args, **kwargs)
+                outputs = old_call_funcs[block_cls](self, *args, **kwargs)
             return outputs
 
         block_cls.__call__ = call_func
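With this change, a caller can capture bf16 and FP8 modules in one graphed callable set. A hedged usage sketch, where the layer sizes, sample shapes, and the DelayedScaling recipe are illustrative assumptions rather than values taken from the commit:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Two transformer blocks: the first captured under FP8 autocast, the
# second kept in bf16. Sizes and shapes are assumptions for illustration.
blocks = (
    te.TransformerLayer(1024, 4096, 16, params_dtype=torch.bfloat16, device="cuda"),
    te.TransformerLayer(1024, 4096, 16, params_dtype=torch.bfloat16, device="cuda"),
)
sample_args = tuple(
    (torch.randn(32, 8, 1024, device="cuda", dtype=torch.bfloat16),) for _ in blocks
)

graphed = te.make_graphed_callables(
    blocks,
    sample_args,
    fp8_enabled=(True, False),    # per-module flags, new in this commit
    fp8_recipe=DelayedScaling(),  # applied only where fp8 is enabled
)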
