
Releases: NVIDIA/TransformerEngine

v2.6

15 Sep 21:12

Release Notes – Release 2.6

Key Features and Enhancements

  • [PyTorch] Added support for gradient accumulation fusion when using FSDP from megatron-core.
  • [PyTorch] Optimized memory usage when capturing NVIDIA® CUDA® graphs with TE via the make_graphed_callables function.
  • [PyTorch] Optimized performance of permute fusion kernels for MoE.
  • [PyTorch] Added support for ONNX export of Transformer Engine modules.
  • [PyTorch] Added a save_original_input option to the Linear and GroupedLinear modules to decouple row-wise (forward) and column-wise (backward) quantization. This option saves memory for certain workloads and training recipes (see the sketch after this list).
  • [PyTorch] Improved performance of MXFP8 quantization kernels.
  • [Core] Improved performance of KV caching kernels.
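
A minimal sketch of the save_original_input option referenced above. The keyword name is taken from the release note; an FP8-capable GPU is assumed, and the exact te.Linear signature should be checked against your installed version:

```python
import torch
import transformer_engine.pytorch as te

# save_original_input keeps the unquantized input for the backward pass,
# decoupling row-wise (forward) from column-wise (backward) quantization.
linear = te.Linear(4096, 4096, save_original_input=True).cuda()
x = torch.randn(32, 4096, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True):
    y = linear(x)
y.sum().backward()
```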

Fixed Issues

  • [PyTorch] Fixed an issue in the LayerNormLinear module where the returned normalization output had a different shape than the input tensor.
  • [PyTorch] Fixed an issue with the align_size calculation in FP8 padding/unpadding modules.
  • [PyTorch] Made miscellaneous fixes and enhancements to the fusible ops (te.sequential) API.
  • [PyTorch] Reduced CPU overhead in various workloads: DelayedScaling recipe, MXFP8 MoE, and pipeline parallelism.
  • [PyTorch] Fixed a bug in the multi-tensor Adam kernel that incorrectly downcast an FP32 tensor to BF16.
  • [PyTorch] Fixed an issue with caching FP8 weights when running validation steps between training steps.
  • [PyTorch] Fixed a logical error that could lead to using a suboptimal attention backend when a better-performing backend is available.
  • [PyTorch] Fixed miscellaneous errors during runtime loading of shared libraries by expanding search paths.
  • [PyTorch] Fixed a use-after-free bug in cases where quantization and normalization are unfused.
  • [Jax] Fixed a crash with grouped GEMM in CUDA version ≥ 12.9.1.
  • [Jax] Fixed a build failure with JAX v0.7.0 caused by the removal of jax.extend.ffi.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

Miscellaneous

There are no miscellaneous issues in this release.

v2.5

28 Jul 16:52

Release Notes – Release 2.5

Key Features and Enhancements

  • Added support for Python 3.12+.
  • Added support for head dimension (head_dim) > 128 for attention for all architectures.
  • [Jax] Added support for sliding window attention (SWA) in context parallel ring attention using THD format and striped sharding.
  • [Jax] Improved performance for per-tensor scaling FP8 recipe.
  • [Jax] Added MXFP8 support for the GroupedDense module and handled the case of zero input tokens.
  • [PyTorch] Enabled FP8 tensor-parallel communication for FP8 block scaling recipe for Hopper by supporting coalesced gather of FP8 quantized tensors.
  • [PyTorch] Optimized the MXFP8 Userbuffers implementation by overlapping the wgrad NCCL all-gather with the dgrad GEMM.
  • [PyTorch] Added support for CPU offloading when using FP8 parameters.
  • [PyTorch] Added support for Context Parallel for Multi Latent Attention (MLA).
  • [PyTorch] Reduced CPU overhead in MoE.
  • [C][PyTorch] Improved performance for FP8 padding and unpadding kernels for MoE.
  • [PyTorch] Added support for FP8 current scaling in the operation-based API (see the sketch after this list).
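
A minimal sketch of FP8 current scaling with the operation-based API, as referenced in the last item. The module names under te.ops (Sequential, Linear) and their use here are assumptions based on these notes; check the fusible-ops documentation for your installed version:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Two fusible linear ops chained with the operation-based API (assumed names).
model = te.ops.Sequential(
    te.ops.Linear(1024, 4096),
    te.ops.Linear(4096, 1024),
).cuda()

x = torch.randn(32, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
    y = model(x)
```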

Fixed Issues

  • [Jax] Fixed a numerical error in the scaled masked softmax kernel.
  • [Jax] Fixed output dtype for FP8 GEMM.
  • [PyTorch] Fixed a bug that appeared when the FP8 recipe is changed in between training steps.
  • [PyTorch] Made miscellaneous fixes in TransformerLayer: passed the missing cu_seqlens and max_seqlen arguments to cross-attention and allowed attn_input_format=thd.
  • [PyTorch] Fixed a crash when loading checkpoints generated by previous Transformer Engine versions.
  • [PyTorch] Made miscellaneous fixes in CPU offloading logic.
  • [PyTorch] Fixed a numerical issue in cross-entropy loss.
  • [C][PyTorch][Jax] Fixed source installation when using NVTE_FRAMEWORK=all.
  • [PyTorch] Fixed a crash in GroupedLinear when using CUDA graphs.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

Miscellaneous

There are no miscellaneous issues in this release.

v2.4

05 Jun 20:44

Release Notes – Release 2.4

Key Features and Enhancements

  • [Jax] Added support for Float8CurrentScaling recipe.
  • [Jax] Added support for logical partitioning axes in TE Flax modules.
  • [Core] Added multiple experimental functions to the C API.
  • [PyTorch] Improved performance by caching device properties.
  • [PyTorch] Made miscellaneous minor improvements to reduce memory consumption for certain workloads.
  • [PyTorch] Added support for MXFP8 recipe when using userbuffers for overlapping TP communication and GEMMs.
  • [PyTorch] Reduced the binary size of the framework extension library from 108 MB to 2 MB.
  • [PyTorch] Introduced a Boolean parameter, rotary_pos_interleaved, in the MultiheadAttention and TransformerLayer modules for interleaved RoPE.
  • [PyTorch] Added support for ignoring tokens in the cross-entropy loss function.
  • [PyTorch] Added support for switching among all supported FP8 recipes during training and checkpointing (see the sketch after this list).
  • [PyTorch] Added various debugging tools via NVIDIA-DL-Framework-Inspect.
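
A minimal sketch of switching recipes between training steps, as referenced above. The recipe classes are the standard ones in transformer_engine.common.recipe, and an FP8-capable GPU is assumed:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

model = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda")

# Step N: run under delayed scaling.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
    model(x).sum().backward()

# Step N+1: the same module can now run under a different recipe, and per the
# note above, checkpoints can be saved and reloaded across recipes.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
    model(x).sum().backward()
```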

Fixed Issues

  • [PyTorch] Fixed a numerical issue when using activation recompute with FP8.
  • [PyTorch] Fixed incorrect output dimensions when using return_layernorm_output in the LayerNormLinear and LayerNormMLP modules.
  • [PyTorch] Fixed a numerical bug when using sequence parallelism in the LayerNorm and RMSNorm modules with Megatron-LM.
  • [PyTorch/Jax] Fixed miscellaneous crashes at import time due to library loading.
  • [Jax] Fixed a crash due to partitioning error when using the LayerNorm or LayerNormMLP module with tensor parallelism.
  • [PyTorch] Fixed an issue where the GIL was held during the entirety of C API calls from the framework extensions, including during NVIDIA® CUDA® kernel execution.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

Miscellaneous

There are no miscellaneous issues in this release.

v2.3

14 May 18:04

Release Notes – Release 2.3

Key Features and Enhancements

  • [PyTorch] Sped up import of the transformer_engine module by lazily compiling functions that use torch.compile.
  • [PyTorch] Enabled FP8 weights when using FSDP.
  • [C][PyTorch] Added support for the Float8 block scaling recipe, as used in the DeepSeek-V3 paper, for Hopper GPUs (see the sketch after this list).
  • [PyTorch] Made miscellaneous fixes to reduce CPU overhead.
  • [PyTorch] Added support for CPU offloading for activation tensors when using FP8 attention.
  • [PyTorch] Enabled MXFP8 recipe for the GroupedLinear module.
  • [PyTorch] Added a feature to decouple the weight-gradient (wgrad) computation from the backward function of Transformer Engine modules. Users can invoke the wgrad computation separately, giving finer-grained control over when gradients are computed in order to support certain advanced parallelism and overlap schemes.
  • [PyTorch] Added support for staggered application of RoPE embeddings to a sequence of inputs in a batch, depending on their starting positions.
  • [All] Added support for RTX 5090.
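
A minimal sketch of enabling the DeepSeek-style FP8 block scaling recipe on Hopper, as referenced above. The recipe class name (Float8BlockScaling) is an assumption based on this note; consult transformer_engine.common.recipe for the exact name in your version:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

block_recipe = recipe.Float8BlockScaling()  # assumed class name
linear = te.Linear(2048, 2048).cuda()
x = torch.randn(128, 2048, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=block_recipe):
    y = linear(x)
y.sum().backward()
```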

Fixed Issues

  • [PyTorch] Fixed a numerical bug with use of custom DDP from megatron-core.
  • [PyTorch] Fixed a crash when using the checkpoint method for activation recompute on non-Transformer Engine modules.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

  • [Jax] Praxis layers have been removed, as PAXML is no longer supported.

Deprecated Features

  • Installing Transformer Engine now requires the --no-build-isolation flag, both when installing the PyPI package and when building from source. Support for installations with build isolation will be removed in a future release.
  • [PyTorch] CPU offloading of weight tensors is deprecated.

v2.2

28 Apr 21:29

Release Notes – Release 2.2

Key Features and Enhancements

  • [PyTorch] Added support for per-tensor current scaling recipe.
  • [PyTorch] Implemented cross-entropy loss with support for splitting computation across multiple devices.
  • [PyTorch] Added support for CPU offloading with Megatron-Core style distributed optimizers.
  • [PyTorch] Added support for KV cache for FusedAttention, FlashAttention, and UnfusedDotProductAttention backends.
  • [PyTorch] Improved bulk TP communication overlap by launching GEMMs on lower-priority streams.
  • [C/PyTorch] Improved performance for P2P-based Tensor Parallel (TP) communication overlap.
  • [Jax] Added support for THD format with ring attention.
  • [Jax] Improved performance and memory usage for causal mask in the cuDNN attention backend.
  • [C] Added multi-node support for NVIDIA® NVLink for TP overlap with userbuffers.

Fixed Issues

  • [PyTorch] Fixed a convergence issue when using context parallelism with a fused attention backend.
  • [PyTorch] Fixed a crash in GroupedLinear when the last input has no tokens.
  • [PyTorch] Made miscellaneous fixes to improve overall performance of the MXFP8 recipe.
  • [PyTorch] Reintroduced support for the return_bias argument in all modules, which was silently ignored in v2.0 and v2.1.
  • [PyTorch] Reintroduced support for FP8 communication for overlapping reduce-scatter and GEMM when using TP overlap with userbuffers.
  • [PyTorch] Fixed gradient accumulation fusion in the LayerNormMLP module.
  • [C/PyTorch] Made miscellaneous numerical fixes to the fused attention backend.
  • [C] Stopped creating a new cublasLtHandle for every GEMM call, preventing memory leaks.
  • [Jax] Fixed shape and sharding inference in fused-attention C++ extension.
  • [Jax] Fixed an import error in the encoder example.

Known Issues in This Release

  • RTX 5090 is currently unsupported for FP8 execution. Support will be added in v2.3.0.
  • Transformer Engine may crash when it is installed via the PyPI registry but is run in an environment with CUDA version < 12.8. A temporary workaround is to install from source until the issue is fixed.

Breaking Changes in This Release

  • [PyTorch] The deprecated interval argument for the DelayedScaling recipe has been removed.
  • [PyTorch] There are multiple breaking changes in the InferenceParams class (see the sketch after this list).
    • New arguments num_heads_kv, head_dim_k, and dtype are required during initialization.
    • The user must call a pre_step method to update the InferenceParams state.
    • The swap_key_value_dict method has been removed, as the step method now automatically reorders the key/value sequences according to their batch indices.
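
A minimal sketch of the reworked InferenceParams flow described above. Only the argument and method names quoted in these notes (num_heads_kv, head_dim_k, dtype, pre_step, step) come from the release; the remaining constructor arguments and the exact pre_step input are assumptions, so check the class documentation for your version:

```python
import torch
import transformer_engine.pytorch as te

inference_params = te.InferenceParams(
    max_batch_size=4,           # assumed to carry over from earlier versions
    max_sequence_length=2048,   # assumed to carry over from earlier versions
    num_heads_kv=8,             # now required at initialization
    head_dim_k=64,              # now required at initialization
    dtype=torch.bfloat16,       # now required at initialization
)

# pre_step must now be called before each decoding step to update the cache
# state; the mapping of sequence id -> number of new tokens below is an
# assumed illustration of its input.
inference_params.pre_step({seq_id: 1 for seq_id in range(4)})
```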

Deprecated Features

There are no deprecated features in this release.

Miscellaneous

  • [PyTorch] The minimum required PyTorch version has been raised to 2.1.

v2.1

17 Mar 17:19

Release Notes – Release 2.1

Key Features and Enhancements

  • [PyTorch] Made the API for fused optimizers (Adam and SGD) consistent with the PyTorch equivalents.
  • [PyTorch] Implemented probability permutation and mask-based permutation in MoE.
  • [PyTorch] Added the store_param_remainders argument for TE optimizers to save memory when storing FP32 master weights for BF16 model weights (see the sketch after this list).
  • [Jax] Added support for THD attention input format for the flax modules.
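
A minimal sketch of the store_param_remainders option referenced above. The store_param_remainders keyword comes from the note; the master_weights flag and the FusedAdam import path are assumptions, so check the optimizer documentation for your version:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.optimizers import FusedAdam

model = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
optimizer = FusedAdam(
    model.parameters(),
    lr=1e-4,
    master_weights=True,          # assumed flag enabling FP32 master copies
    store_param_remainders=True,  # store only the extra 16 bits per BF16 weight
)
```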

Fixed Issues

  • [PyPI] Fixed an issue when TE is installed from PyPI in an environment where TE had already been installed from source: the wheel installation was incorrect, resulting in an application crash at runtime.
  • [PyTorch] Fixed an issue with QuantizedTensor types when executing operations such as chunk or split, which have different shapes for input and output.
  • [PyTorch] Made miscellaneous fixes to the attention backend for execution on Blackwell GPUs.
  • [PyTorch] Fixed a crash when using Context Parallelism with FP8 weights.
  • [PyTorch] Fixed a crash when using fused gradient accumulation with grouped GEMMs (MoE).
  • [Jax/Flax] Changed Flax modules to use dtype to initialize their parameters while inferring the compute type from the input data type.

Known Issues in This Release

  • [PyTorch] The return_bias option in LayerNormLinear and LayerNormMLP, used internally in TransformerLayer, is silently ignored in this release, resulting in incorrect results. This issue was resolved in #1569, and the fix will be part of the 2.2 release.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

  • [Jax] The fused_attn_thd API call is deprecated in favor of fused_attn, which supports THD format.
  • [Jax] The mask positional argument is deprecated in favor of sequence_descriptor.

v2.0

13 Feb 22:06

Release Notes – Release 2.0

Key Features and Enhancements

  • [C] Added MXFP8 support in functions for casting, GEMMs, normalizations, and activations.
  • [C] Added generic API for quantized tensors, including generic quantize and dequantize functions.
  • [C] Exposed cuDNN LayerNorm and RMSNorm kernels.
  • [PyTorch] Added MXFP8 recipe (see the sketch after this list).
  • [PyTorch] Added MXFP8 support in Linear, LayerNormLinear, LayerNormMLP, and TransformerLayer modules, and in the operation-based API.
  • [PyTorch] Changed the default quantization scheme from FP8 to MXFP8 for Blackwell GPUs.
  • [PyTorch] Added a custom tensor class for MXFP8 data.
  • [PyTorch] Reduced CPU overhead in FP8/MXFP8 execution.
  • [PyTorch] Enabled efficient handling of FP8 parameters with PyTorch FSDP2.
  • [PyTorch] Expanded the support matrix for Sliding Window Attention.
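
A minimal sketch of opting into the new MXFP8 recipe, as referenced above. The recipe class name (MXFP8BlockScaling) is an assumption; per the notes, on Blackwell GPUs it is also the default when no fp8_recipe is specified:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.LayerNormLinear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda")

# Explicitly select the MXFP8 recipe (assumed class name).
with te.fp8_autocast(enabled=True, fp8_recipe=recipe.MXFP8BlockScaling()):
    y = layer(x)
```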

Fixed Issues

  • [PyTorch] Fixed bugs in capturing CUDA Graphs for MoE models.
  • [PyTorch] Fixed errors with the FP8 state when loading Hugging Face checkpoints.

Known Issues in This Release

  • [PyTorch] Overlapping tensor-parallel communication with Userbuffers is not supported with MXFP8.
  • [PyTorch] When running linear modules with MXFP8, the memory footprint and tensor-parallel communication volume are larger than necessary.
  • [PyTorch] Userbuffers support in the operation-based API is disabled.
  • [PyTorch] The return_bias option in LayerNormLinear and LayerNormMLP, used internally in TransformerLayer, is silently ignored in this release, resulting in incorrect results. This issue was resolved in #1569, and the fix will be part of the 2.2 release.

Breaking Changes in This Release

  • [C] Updated minimum requirements to CUDA 12.1 and cuDNN 9.3.
  • [PaddlePaddle] Removed PaddlePaddle integration.
  • [PyTorch] Changed the default quantization scheme from FP8 to MXFP8 for Blackwell GPUs.
  • [PyTorch] Removed support for exporting ONNX models. Support for ONNX export will be re-enabled in a future release.

Deprecated Features

There are no deprecated features in this release.

v1.13

09 Dec 19:29

Release Notes – Release 1.13

Key Features and Enhancements

  • [C/PyTorch/Jax] Added support for THD layout for MQA/GQA.
  • [Jax] Expanded FFI (Foreign Function Interface) support to include quantization, transpose, layernorms, fused-attention, and CUDA graphs; fixed miscellaneous bugs in the existing FFI implementations.
  • [Jax] Added support for Ring attention for context parallelism.
  • [PyTorch] Expanded support for the Sequential/Operations Based API to include activations, communication overlap, normalizations, and other fusions.
  • [PyTorch] Made miscellaneous fixes to reduce CPU overhead during execution.
  • [PyTorch] Leveraged cuDNN 9.6+ to reduce memory usage for THD input format to attention.

Fixed Issues

  • [PyTorch] Fixed a crash that could occur when using FlashAttention with context parallelism.
  • [C/Jax] Adopted 64-bit offsets to fix overflow for large tensors in the cuDNN attention backend.
  • [C/Jax] Fixed build when using clang compiler to build JAX native extensions.
  • [PyTorch] Fixed a crash when importing Transformer Engine on CPU-only systems.
  • [PyTorch] Fixed a crash when using context parallelism with RoPE.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

  • Transformer Engine support for the PaddlePaddle framework is deprecated, and will be fully removed in version 2.0.
  • Support for exporting Transformer Engine modules via ONNX is deprecated, and will be removed in version 2.0. This feature will be supported again in a later minor release of version 2.

v1.12

18 Nov 22:52

Release Notes – Release 1.12

Key Features and Enhancements

  • [PyTorch] Added a rotary_base argument for RoPE instead of hard-coding the value to 10000 (see the sketch after this list).
  • [PyTorch] Added support for the pool argument in the make_graphed_callables API.
  • [PyTorch] Made miscellaneous minor improvements to mitigate CPU overhead.
  • [PyTorch] Expanded fused RoPE kernel support to include context parallelism and the “thd” qkv-format.
  • [PyTorch] Made flash-attn an optional dependency.
  • [JAX] Added support for sliding window attention.
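
A minimal sketch of the new rotary_base argument referenced above. The keyword name comes from the note; the RotaryPositionEmbedding import path is an assumption, so check your installed version:

```python
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

# Previously the rotary base was fixed at 10000; it is now configurable,
# e.g. for long-context RoPE variants that use a larger base.
rope = RotaryPositionEmbedding(128, rotary_base=500000)
freqs = rope(4096)  # rotary frequencies for a 4096-token sequence
```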

Fixed Issues

  • [PyTorch/C] Fixed window size calculation when using the cuDNN attention backend.
  • [PyTorch] Fixed miscellaneous bugs in the flash-attn version 3 backend.
  • [PyTorch] Fixed an issue using the flash-attn backend with context parallelism.
  • [PyTorch] Fixed a numerical error when using FP8 with activation recompute.
  • [PyTorch] Fixed an issue in the backward pass of the GroupedLinear class when weights don’t require gradients.
  • [JAX] Fixed a numerical bug in the cuDNN attention backend when using Context Parallelism.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

v1.11

08 Oct 21:27

Release Notes – Release 1.11

Key Features and Enhancements

  • [PyTorch] Added DTensor support for optimizers.
  • [PyTorch] Added a context parallel implementation with QKV all-to-all collectives.
  • [PyTorch] Added support for CPU offloading when using FP8 attention.
  • [PyTorch] Implemented padding and unpadding modules for FP8 that improve end-to-end performance of MoE models by ~2%.
  • [C/PyTorch] Added support for permutation operations for MoE and exposed them in the C API.
  • [PyTorch] Added support for RoPE when using FP8 attention.
  • [PyTorch] Added support for FlashAttention-3.
  • [JAX] Implemented context parallel fused attention using allgather and reduce-scatter collectives.

Fixed Issues

  • [PyTorch] Fixed a crash in the fused Adam optimizer when master parameters are not set.
  • [PyTorch] Fixed a crash when using activation recompute with Python 3.10.
  • [PyTorch] Made miscellaneous fixes in the logic to select the correct attention backend.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.