
Commit a923abe

vcherepanov-nv, Selvaraj Anandaraj, pre-commit-ci[bot], and pggPL authored and committed
Temporarily remove comm_gemm tests (NVIDIA#2133)
This squashed commit also carries the following changes:

* [PyTorch] Disable determinism for sm100 (NVIDIA#2130): disable determinism for sm100+ with cuDNN < 9.14; remove sm100 from the determinism table. (Charlene Yang)
* [PyTorch] ONNX export of FP8 current scaling (NVIDIA#2068): compute amax in the normalization forward pass for current scaling in untuned kernels; follow-up fixes. (Jan Bielak, Pawel Gadzinski)
* [PyTorch][MoE] Tentative fix for replacing from_blob with empty for experts receiving zero tokens (NVIDIA#2134): use torch.empty instead of from_blob for empty shapes. (zhongboz)
* build: pull cached wheels (NVIDIA#2127); update setup.py. (oliver könig)
* feat: add support for multiple quantization modes in the UB communicators (NVIDIA#2043).
* [Common] Add checks to CUDA kernel launches and CUDA API calls (NVIDIA#2074): remove exceptions from destructors; fix a weird dispatch in LayerNorm/RMSNorm. (Xin Yao, Tim Moon)
* [PyTorch] Support bf16+fp8 CUDA graphs (NVIDIA#2098). (Robin Zhang)
* Dropout with 8-bit RNG (NVIDIA#2014): add a dropout kernel with 8-bit RNG; do not enforce that the dropout probability is representable in 8 bits; fix a small statistical bug from using less-equal instead of less-than; refactor kernel implementations and add comments; interpret masks as bytes rather than 16-bit uints. (Tim Moon, Vasudevan Rengasamy)
* Create GPU reload buffers on the main stream (NVIDIA#2131). (Selvaraj Anandaraj)
* MXFP8 unfused quantization support, refined unit tests, removed unnecessary quantization code, plus follow-up fixes. (Varun Thumbe)
* Add cuBLASMp-backed GEMM-like API to TE common (NVIDIA#1824): AG-GEMM, GEMM-RS and GEMM-AR paths; FP8 support; NVTX markers on API functions; replace libcal with NCCL; remove the MPI dependency; abstract the cuBLASMp algo behind our own enum; NVTE_WITH_CUBLASMP off by default. (Vladimir Cherepanov, Przemyslaw Tredak)
* FP8 AllGather in FP8 GroupedGEMM + fix stream usage issue (NVIDIA#2086): support current-scaling FP8 quantization with a given amax; FP8 all-gather in forward and BF16 reduce-scatter in backward; the workflow is AR-max -> FP8 quant -> FP8 AG -> FP8 GroupedGEMM (a toy sketch of this flow follows the commit message); add all layout support for Blackwell+; fix the wrong stream used by device-to-device copies in the grouped-GEMM FFI. (Ming Huang)
* [JAX] Delay MeshResource validation until first usage (NVIDIA#2124). (Jeremy Berchtold)
* [JAX] Decouple Recipe and ScalingMode (NVIDIA#1728): expose the global QuantizeConfig instance through a getter; rename UsageType to TensorSource. (Jeremy Berchtold)
* [JAX] Add a dot_1_output sharding constraint and use AXIS_IS_UNSHARDED (NVIDIA#2128). (Phuong Nguyen)
* [JAX] Add amax input to DBiasQuantizePrimitive and FFI (NVIDIA#2118): make sure amax is initialized to zero; fix the sharding rule. (Phuong Nguyen)
* Further relax constraints to cuDNN 9.13 for disabling fused attention for KV caching (NVIDIA#2121). (Kshitij Lakhani)

Signed-off-by: Vladimir Cherepanov <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
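The AR-max -> FP8 quant -> FP8 AG -> FP8 GroupedGEMM workflow listed for NVIDIA#2086 is easiest to read as code. The sketch below only illustrates that communication pattern in plain PyTorch; `quantize_fp8` and the final matmul are toy stand-ins, not Transformer Engine APIs, and the real implementation lives in the JAX grouped-GEMM path of this repository.

```python
# Toy sketch of the AR-max -> FP8 quant -> FP8 AG -> GEMM flow (not a TE API).
import torch
import torch.distributed as dist

FP8_MAX = 448.0  # largest representable value of float8_e4m3fn


def quantize_fp8(x, amax):
    """Toy per-tensor current-scaling quantization using a shared amax."""
    scale = FP8_MAX / torch.clamp(amax, min=1e-12)
    return (x * scale).to(torch.float8_e4m3fn), 1.0 / scale


def fp8_allgather_matmul(x_local, weight, tp_group):
    # 1. AR-max: all-reduce the local amax so every rank quantizes with the same scale.
    amax = x_local.abs().amax().float()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)

    # 2. FP8 quant: quantize the local shard with the shared amax.
    x_fp8_local, scale_inv = quantize_fp8(x_local, amax)

    # 3. FP8 AG: all-gather the FP8 shards as raw bytes (half the traffic of BF16).
    payload = x_fp8_local.view(torch.uint8)
    shards = [torch.empty_like(payload) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, payload, group=tp_group)
    x_fp8 = torch.cat(shards, dim=0).view(torch.float8_e4m3fn)

    # 4. GEMM: a real implementation calls an FP8 grouped GEMM; here we simply
    #    dequantize and matmul in BF16 to keep the sketch self-contained.
    return (x_fp8.to(torch.bfloat16) * scale_inv) @ weight
```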
1 parent 4c75c2f commit a923abe

File tree: 75 files changed, +1386 −394 lines


docs/api/pytorch.rst

Lines changed: 4 additions & 1 deletion

@@ -49,7 +49,7 @@ pyTorch
 
 .. autoapifunction:: transformer_engine.pytorch.moe_permute
 
-.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
+.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
 
 .. autoapifunction:: transformer_engine.pytorch.moe_unpermute
 
@@ -62,3 +62,6 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.initialize_ub
 
 .. autoapifunction:: transformer_engine.pytorch.destroy_ub
+
+.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
+    :members: FP8, NONE

docs/examples/onnx/onnx_export.ipynb

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
     "\n",
     "<b>Note:</b>\n",
     "\n",
-    "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling and MXFP8.\n",
+    "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling, FP8 current scaling and MXFP8.\n",
     "\n",
     "</div>\n",
     "\n",

examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py

Lines changed: 3 additions & 1 deletion

@@ -263,7 +263,9 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)
     te.module.base.initialize_ub(
         [batched_size, hidden_size],
         tp_size,
-        use_fp8=opts.fp8,
+        quantization_modes=[
+            UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+        ],
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
     )
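For readers tracking the API change above: `initialize_ub` no longer takes `use_fp8` but a list of quantization modes. Below is a minimal sketch of the new call, assuming `transformer_engine.pytorch` exposes `UserBufferQuantizationMode` as documented in docs/api/pytorch.rst; the shape, tensor-parallel size, and backend are illustrative values, and the usual `torch.distributed` initialization is assumed to have happened already.

```python
# Minimal sketch of the new initialize_ub signature (illustrative values).
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import UserBufferQuantizationMode

use_fp8 = True
te.module.base.initialize_ub(
    [8192, 4096],  # [batch size * sequence length, hidden size]
    8,             # tensor-parallel world size
    quantization_modes=[
        UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
    ],
    dtype=torch.bfloat16,
    bootstrap_backend="nccl",
)
```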

tests/cpp/CMakeLists.txt

Lines changed: 0 additions & 1 deletion

@@ -43,6 +43,5 @@ include_directories(${CMAKE_SOURCE_DIR})
 find_package(CUDAToolkit REQUIRED)
 include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
 
-add_subdirectory(comm_gemm)
 add_subdirectory(operator)
 add_subdirectory(util)

tests/pytorch/distributed/run_layer_with_overlap.py

Lines changed: 65 additions & 12 deletions

@@ -12,6 +12,8 @@
 import warnings
 import pprint
 import yaml
+from contextlib import nullcontext
+from functools import partial
 
 import torch
 import torch.distributed as dist
@@ -35,9 +37,10 @@ def __init__(self, module, num_layers, *args, **kwargs):
         self.num_layers = num_layers
         self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)])
 
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
+    def forward(self, x, layer_contexts):
+        for layer, context in zip(self.layers, layer_contexts):
+            with context():
+                x = layer(x)
         return x
 
 
@@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None):
         default=False,
         help="Print out additional debug information.",
     )
+    parser.add_argument(
+        "--first-last-layers-bf16",
+        action="store_true",
+        default=False,
+        help="Use bf16 for first and last N layers.",
+    )
+    parser.add_argument(
+        "--num-layers-at-start-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the start to run in bf16.",
+    )
+    parser.add_argument(
+        "--num-layers-at-end-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the end to run in bf16.",
+    )
     args = parser.parse_args(argv, namespace)
 
     if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
         warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
         args.use_cuda_graphs = False
 
+    if not args.first_last_layers_bf16 and (
+        args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0
+    ):
+        warnings.warn(
+            "num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when"
+            " first-last-layers-bf16 is enabled!"
+        )
+        args.num_layers_at_start_in_bf16 = 0
+        args.num_layers_at_end_in_bf16 = 0
+
+    if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers:
+        raise ValueError(
+            "num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to"
+            " num-layers!"
+        )
+
     return args
 
 
@@ -381,10 +418,17 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
         "qkv_dgrad": {"method": "ring_exchange"},
         "fc1_dgrad": {"method": "ring_exchange"},
     }
+
+    quantization_modes = [
+        UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+    ]
+    if opts.first_last_layers_bf16 and opts.fp8:
+        quantization_modes.append(UserBufferQuantizationMode.NONE)
+
     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
         opts.tp,
-        use_fp8=opts.fp8,
+        quantization_modes=quantization_modes,
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
         ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
@@ -423,6 +467,16 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
     elif opts.quantization == "mxfp8":
        fp8_recipe = MXFP8BlockScaling()
 
+    layer_contexts = [
+        (
+            partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
+            if opts.num_layers_at_start_in_bf16 <= i
+            and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
+            else nullcontext
+        )
+        for i in range(opts.num_layers)
+    ]
+
     # Prepare random input tensors
     test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
     test_x.retain_grad()
@@ -435,14 +489,13 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
     # Execute fwd/bwd and collect tensors to test
     def run_fwd_bwd(model, x):
         with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
-                y = model(x)
-                if isinstance(y, tuple):
-                    out, *_ = y
-                else:
-                    out = y
-                loss = out.sum()
-                loss.backward()
+            y = model(x, layer_contexts)
+            if isinstance(y, tuple):
+                out, *_ = y
+            else:
+                out = y
+            loss = out.sum()
+            loss.backward()
         return out
 
     torch_rng_state = torch.get_rng_state()
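The change above replaces a single global `fp8_autocast` with one context per layer, so that the first and last few layers can stay in bf16 while the middle layers run in FP8. A condensed sketch of that selection logic, using the same `te.fp8_autocast` arguments as the script; the argument names below are stand-ins for the script's own options.

```python
# Sketch: pick an autocast context per layer (bf16 at the ends, FP8 in the middle).
from contextlib import nullcontext
from functools import partial

import transformer_engine.pytorch as te


def make_layer_contexts(num_layers, n_start_bf16, n_end_bf16, fp8_enabled, fp8_recipe, fp8_group):
    return [
        (
            partial(te.fp8_autocast, enabled=fp8_enabled, fp8_recipe=fp8_recipe, fp8_group=fp8_group)
            if n_start_bf16 <= i < (num_layers - n_end_bf16)
            else nullcontext
        )
        for i in range(num_layers)
    ]


# The wrapper's forward() then enters one context per layer:
#   for layer, context in zip(self.layers, layer_contexts):
#       with context():
#           x = layer(x)
```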

tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py

Lines changed: 7 additions & 1 deletion

@@ -506,7 +506,13 @@ def main() -> None:
             model_config.num_heads * model_config.head_dim,
         ],
         torch.distributed.get_world_size(group),
-        use_fp8=model_config.quantization is not None,
+        quantization_modes=[
+            (
+                UserBufferQuantizationMode.FP8
+                if model_config.quantization is not None
+                else UserBufferQuantizationMode.NONE
+            )
+        ],
         dtype=model_config.dtype,
         bootstrap_backend=bootstrap_backend,
         ub_cfgs=userbuffer_configs,

tests/pytorch/test_fusible_ops.py

Lines changed: 40 additions & 16 deletions

@@ -1749,43 +1749,65 @@ def test_constant_scale(
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
 
-    @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
+    @pytest.mark.parametrize("prob", (0.0625, 0.5, 0.75))
     @pytest.mark.parametrize("is_training", (True, False))
-    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
+    @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
+    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16), (128, 128)))
     @pytest.mark.parametrize("dtype", _dtypes)
     def test_dropout(
         self,
         *,
         prob: float,
         is_training: bool,
+        quantization: Optional[str],
         shape: Iterable[int],
         dtype: torch.dtype,
         device: torch.device = "cuda",
     ):
 
+        # Skip invalid configurations
+        quantized_input = quantization is not None
+        maybe_skip_quantization(quantization, dims=shape, device=device)
+
         # Random data
-        x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
-        x_test = x_ref.clone().requires_grad_()
-        dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
-        dy_test = dy_ref.clone()
+        # Note: Shift values to make sure inputs are non-zero
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            test_is_quantized=quantized_input,
+        )
+        with torch.no_grad():
+            x_test += 1
+            x_ref.copy_(x_test)
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
 
         # Apply dropout
         op = te_ops.Dropout(prob)
         if is_training:
             op.train()
         else:
             op.eval()
-        y = op(x_test)
-        y.backward(dy_test)
+        y_test = op(x_test)
+        y_test.backward(dy_test)
 
         # Check values
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
         if is_training:
-            mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
-            torch.testing.assert_close(y, x_ref * mask)
-            torch.testing.assert_close(x_test.grad, dy_ref * mask)
+            tols = dtype_tols(dtype)
+            mask = ((y_test != 0) / (1 - prob)).to(dtype=dtype)
+            torch.testing.assert_close(y_test, x_ref * mask, **tols)
+            torch.testing.assert_close(dx_test, dy_ref * mask, **tols)
         else:
-            torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
-            torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
+            torch.testing.assert_close(y_test, x_ref, rtol=0, atol=0)
+            torch.testing.assert_close(dx_test, dy_ref, rtol=0, atol=0)
 
         # Hypothesis testing for number of zeros
         # Note: A Bernoulli random variable with probability p has
@@ -1797,9 +1819,11 @@ def test_dropout(
         # p-value is less than 1% and we assume that the dropout
         # distribution is incorrect.
         if is_training:
-            prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
-            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
-            assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
+            prob_observed = 1 - torch.count_nonzero(y_test).item() / y_test.numel()
+            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y_test.numel())
+            assert (
+                abs(z_score) < 2.5758
+            ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"
 
 
 class TestFusedOps:
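The assertion in the hunk above is a one-sample z-test on the fraction of zeroed elements: with n elements and dropout probability p, the observed zero fraction p_obs should satisfy |z| = |p_obs - p| / sqrt(p(1 - p) / n) < 2.5758 at the 99% level. A standalone illustration of the same check with plain torch.nn.functional.dropout (values are illustrative, not taken from the test):

```python
# Standalone illustration of the zero-fraction z-test used in test_dropout.
import math
import torch

p = 0.0625
x = torch.rand(128, 128) + 1  # shift so the input itself has no zeros
y = torch.nn.functional.dropout(x, p=p, training=True)

prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
z_score = (prob_observed - p) / math.sqrt(p * (1 - p) / y.numel())
assert abs(z_score) < 2.5758, f"zero fraction outside the 99% CI ({p=}, {prob_observed=})"
```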

tests/pytorch/test_numerics.py

Lines changed: 13 additions & 4 deletions

@@ -122,13 +122,18 @@
 
 
 def is_fused_attn_available(
-    config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
+    config: ModelConfig,
+    dtype: torch.dtype,
+    qkv_layout="bshd_bshd_bshd",
+    is_training=True,
+    deterministic=False,
 ):
     _, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
         is_training=is_training,
+        deterministic=deterministic,
     )
     return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
 
@@ -839,7 +844,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
 @pytest.mark.parametrize("model", ["126m"])
 def test_gpt_checkpointing(dtype, bs, model):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype):
+    if not is_fused_attn_available(config, dtype, deterministic=True):
         pytest.skip("No attention backend available.")
     outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
     outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
@@ -887,7 +892,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
 def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+    if not is_fused_attn_available(
+        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
+    ):
         pytest.skip("No attention backend available.")
 
     te_gpt = TransformerLayer(
@@ -1000,7 +1007,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
 @pytest.mark.parametrize("mask_type", mask_types)
 def test_mha_accuracy(dtype, bs, model, mask_type):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+    if not is_fused_attn_available(
+        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
+    ):
         pytest.skip("No attention backend available.")
 
     te_mha = MultiheadAttention(

tests/pytorch/test_onnx_export.py

Lines changed: 13 additions & 5 deletions

@@ -65,6 +65,7 @@
     fp8_recipes.append(recipe.MXFP8BlockScaling())
 if fp8_available:
     fp8_recipes.append(recipe.DelayedScaling())
+    fp8_recipes.append(recipe.Float8CurrentScaling())
 fp8_recipes.append(None)
 
 supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
@@ -81,11 +82,11 @@
     ],
     outputs=[PyCustomOpDef.dt_uint8],
 )
-def trt_fp8_quantize(t, scale):
+def trt_fp8_quantize(t, scale_inv):
     """FP8 quantization extension for ONNX Runtime."""
     x = torch.from_numpy(t).cuda()
     q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale).cuda(),
+        scale=1 / torch.from_numpy(scale_inv).cuda(),
         amax=torch.zeros([1]).cuda(),
         fp8_dtype=tex.DType.kFloat8E4M3,
     )
@@ -101,11 +102,11 @@ def trt_fp8_quantize(t, scale):
     ],
     outputs=[PyCustomOpDef.dt_float],
 )
-def trt_fp8_dequantize(t, scale):
+def trt_fp8_dequantize(t, scale_inv):
     """FP8 dequantization extension for ONNX Runtime."""
     x = torch.from_numpy(t).cuda()
     q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale).cuda(),
+        scale=1 / torch.from_numpy(scale_inv).cuda(),
         amax=torch.zeros([1]).cuda(),
         fp8_dtype=tex.DType.kFloat8E4M3,
     )
@@ -593,7 +594,9 @@ def _test_export_layernorm_linear(
         fname,
         inp,
         model,
-        atol=1e-3,
+        # For current scaling we use Float8Quantizer in tests + amax computed by hand,
+        # which has slightly different numerics than Float8CurrentScalingQuantizer.
+        atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
         is_fp8=fp8_recipe is not None,
         te_outputs=te_outputs,
     )
@@ -1150,6 +1153,11 @@ def test_trt_integration(fp8_recipe: recipe.Recipe):
         ffn_hidden_size=128,
         num_attention_heads=4,
     ).eval()
+
+    if type(fp8_recipe) == recipe.Float8CurrentScaling:
+        # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
+        model = te.LayerNormMLP(128, 128)
+
     inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
 
     with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
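Since the recipe list above now includes Float8CurrentScaling, here is a minimal sketch of running one of these modules under that recipe. It only uses names that appear in this diff (te.LayerNormMLP, te.fp8_autocast, recipe.Float8CurrentScaling); the `transformer_engine.common.recipe` import path is assumed, and the actual ONNX/TRT export goes through the test suite's own helpers, which are omitted here.

```python
# Minimal sketch: forward pass of a small TE module under FP8 current scaling.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

model = te.LayerNormMLP(128, 128).eval()
inp = torch.randn([16, 16, 128], device="cuda", requires_grad=False)

fp8_recipe = recipe.Float8CurrentScaling()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)
print(out.shape)
```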
