
Commit 7a7225c

Release v1.12
2 parents c27ee60 + 7f2afaa commit 7a7225c

File tree

105 files changed: +4410 −1575 lines


.github/workflows/trigger-ci.yml

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,9 @@ jobs:
         || github.actor == 'yaox12'
         || github.actor == 'huanghua1994'
         || github.actor == 'mgoldfarb-nvidia'
+        || github.actor == 'pggPL'
+        || github.actor == 'vasunvidia'
+        || github.actor == 'erhoo82'
       )
     steps:
       - name: Check if comment is issued by authorized person

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -39,3 +39,4 @@ develop-eggs/
 dist/
 downloads/
 .pytest_cache/
+compile_commands.json

3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 146 files

build_tools/VERSION.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-1.11.0
+1.12.0

build_tools/pytorch.py

Lines changed: 4 additions & 4 deletions
@@ -80,10 +80,10 @@ def setup_pytorch_extension(
         )
     )

-    if "80" in cuda_architectures:
-        nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
-    if "90" in cuda_architectures:
-        nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
+    for arch in cuda_architectures.split(";"):
+        if arch == "70":
+            continue  # Already handled
+        nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

     # Libraries
     library_dirs = []
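
This change generalizes the old hard-coded sm_80/sm_90 handling to any semicolon-separated architecture list. Below is a minimal standalone sketch of the flag expansion the new loop performs; the helper name is ours and is not part of build_tools/pytorch.py.

    # Illustrative only: mirrors the new loop in setup_pytorch_extension().
    def gencode_flags(cuda_architectures: str) -> list:
        """Expand a semicolon-separated arch list into nvcc -gencode flags."""
        flags = []
        for arch in cuda_architectures.split(";"):
            if arch == "70":
                continue  # sm_70 is handled separately, as noted in the diff
            flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
        return flags

    print(gencode_flags("70;80;90"))
    # ['-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_90,code=sm_90']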

docs/faq.rst

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
..
    Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Frequently Asked Questions (FAQ)
================================

FP8 checkpoint compatibility
----------------------------

Transformer Engine has supported FP8 attention since version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As FP8 attention support has expanded from one backend to multiple backends, the location of the `._extra_state` key has also shifted.

Here, we take the `MultiheadAttention` module as an example. Its FP8 attention metadata in Transformer Engine 1.11 is stored as `core_attention._extra_state`, as shown below.

.. code-block:: python

    >>> import torch
    >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
    >>> with fp8_model_init(enabled=True):
    ...     mha = MultiheadAttention(
    ...         hidden_size=1024,
    ...         num_attention_heads=16,
    ...         bias=True,
    ...         params_dtype=torch.bfloat16,
    ...         input_layernorm=False,
    ...         fuse_qkv_params=True,
    ...         attention_type="self",
    ...         qkv_weight_interleaved=True,
    ...     ).to(dtype=torch.bfloat16, device="cuda")
    ...
    >>> state_dict = mha.state_dict()
    >>> print(state_dict.keys())
    odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])

Here is a full list of the checkpoint save/load behaviors across Transformer Engine versions.

.. list-table::

   * - **Version: <= 1.5**

       - Saves no FP8 metadata since FP8 attention is not supported
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Loads no FP8 metadata
         :> 1.5: Error: unexpected key
   * - **Version: 1.6, 1.7**

       - Saves FP8 metadata to `core_attention.fused_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :1.6, 1.7: Loads FP8 metadata from the checkpoint
         :>= 1.8: Error: unexpected key
   * - **Version: >= 1.8, <= 1.11**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :1.6, 1.7: This save/load combination relies on users to map the 1.6/1.7 key to the 1.8-1.11 key; otherwise, FP8 metadata is initialized to the default, i.e. 1s for scaling factors and 0s for amaxes. In this `MultiheadAttention` example, the mapping can be done by

           .. code-block:: python

               >>> state_dict["core_attention._extra_state"] = \
                       state_dict["core_attention.fused_attention._extra_state"]
               >>> del state_dict["core_attention.fused_attention._extra_state"]

         :>= 1.8: Loads FP8 metadata from the checkpoint
   * - **Version: >= 1.12**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :>= 1.6: Loads FP8 metadata from the checkpoint
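
For users migrating a 1.6/1.7 checkpoint into 1.8 or later (including 1.12), the table above amounts to a single key rename. Here is a minimal sketch for a `MultiheadAttention`-style state dict; the helper name is hypothetical and only the key names come from the FAQ table.

    # Hypothetical helper; only the key names are taken from the FAQ above.
    def remap_fp8_attention_keys(state_dict):
        """Rename the 1.6/1.7 FP8 attention metadata key to its >= 1.8 location."""
        old_key = "core_attention.fused_attention._extra_state"
        new_key = "core_attention._extra_state"
        if old_key in state_dict and new_key not in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)
        return state_dict

In a full model, these keys carry module-path prefixes, so in practice the rename would be applied to every key ending in the old suffix.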

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ Transformer Engine documentation

    installation
    examples/quickstart.ipynb
+   faq

 .. toctree::
    :hidden:

examples/README.md

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# Examples

We provide a variety of examples for deep learning frameworks including [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/jax-ml/jax), and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle).
Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/TransformerEngine/tree/main/docs/examples) and a selection of [third-party examples](#third-party). Please be aware that these third-party examples might need specific, older versions of dependencies to function properly.

# PyTorch

- [Accelerate Hugging Face Llama models with TE](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb)
  - Provides code examples and explanations for integrating TE with LLaMA2 models.
- [PyTorch FSDP with FP8](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/fsdp)
  - **Distributed Training**: How to set up and run distributed training using PyTorch's FullyShardedDataParallel (FSDP) strategy.
  - **TE Integration**: Instructions on integrating TE/FP8 with PyTorch for optimized performance.
  - **Checkpointing**: Methods for applying activation checkpointing to manage memory usage during training.
- [Attention backends in TE](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/attention/attention.ipynb)
  - **Attention Backends**: Describes the attention backends supported by Transformer Engine, including framework-native, fused, and flash-attention backends, and their performance benefits.
  - **Flash vs. Non-Flash**: Compares the flash algorithm with the standard non-flash algorithm, highlighting memory and computational efficiency improvements.
  - **Backend Selection**: Details the logic for selecting the most appropriate backend based on availability and performance, and provides user control options for backend selection.
- [Overlapping Communication with GEMM](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/comm_gemm_overlap)
  - Training a TE module with GEMM and communication overlap, including various configurations and command-line arguments for customization.
- [Performance Optimizations](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/advanced_optimizations.ipynb)
  - **Multi-GPU Training**: How to use TE with data, tensor, and sequence parallelism.
  - **Gradient Accumulation Fusion**: Utilizing Tensor Cores to accumulate outputs directly into FP32 for better numerical accuracy.
  - **FP8 Weight Caching**: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency.
- [Introduction to FP8](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/fp8_primer.ipynb)
  - Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage.
- [TE Quickstart](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb)
  - Introduction to TE, building a Transformer layer using PyTorch, and instructions on integrating TE modules like Linear and LayerNorm.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/mnist)

# JAX
- [Basic Transformer Encoder Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/encoder)
  - Single GPU Training: Demonstrates setting up and training a Transformer model using a single GPU.
  - Data Parallelism: Scale training across multiple GPUs using data parallelism.
  - Model Parallelism: Divide a model across multiple GPUs for parallel training.
  - Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist)

# PaddlePaddle
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/paddle/mnist)

# Third party
- [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine)
  - Scripts for training with Accelerate and TE. Supports single-GPU and multi-GPU training via DDP, FSDP, and DeepSpeed ZeRO 1-3.

examples/jax/encoder/common.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache

from transformer_engine.transformer_engine_jax import get_device_compute_capability


@lru_cache
def is_bf16_supported():
    """Return whether BF16 is supported in hardware"""
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 80
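
The encoder tests use this helper in `unittest.skipIf` decorators, as shown in the next file's diff. As a minimal sketch of another way it can be used, the dtype-selection snippet below is illustrative and not part of the commit.

    import jax.numpy as jnp

    from common import is_bf16_supported  # the helper added above

    # Illustrative: fall back to FP32 on GPUs older than compute capability 8.0.
    dtype = jnp.bfloat16 if is_bf16_supported() else jnp.float32
    print(f"Encoder tests would run with dtype: {dtype}")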

examples/jax/encoder/test_model_parallel_encoder.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,8 @@
 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax

+from common import is_bf16_supported
+
 DEVICE_DP_AXIS = "data"
 DEVICE_TP_AXIS = "model"
 NAMED_BROADCAST_AXIS = "my_broadcast_axis"
@@ -434,6 +436,7 @@ def setUpClass(cls):
         """Run 3 epochs for testing"""
         cls.args = encoder_parser(["--epochs", "3"])

+    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
@@ -446,6 +449,7 @@ def test_te_fp8(self):
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.45 and actual[1] > 0.79

+    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_sp(self):
         """Test Transformer Engine with BF16 + SP"""
         self.args.enable_sp = True
