
Commit 7fa0f55

Authored by: vthumbe1503, pre-commit-ci[bot], yaox12, timmoon10, Jonathan Mitchell
[Pytorch] Support for Swiglu Activation used in GPT OSS (#2161)
Summary of the changes in this PR (author: Varun Thumbe):

* Integrate the GPT OSS style SwiGLU activation: initial draft of the integration, all gated-activation kernels updated, and the new activation covered by pytest.
* Hook the new activation up from PyTorch to TE (the initial hook-up is redundant; refactoring is deferred to a later PR).
* Restrict the number of cases that use unfused quantization, since some fp8->fp8 cases are already handled by cuBLAS.
* Add MXFP8 unfused-quantization support, refine the unit test, and remove unnecessary quantization code.
* Fix an MXFP8 kernel bug: the clamped-SwiGLU parameter was not being passed correctly.
* Fix a minor bug: the quantizer should not be None for unfused quantization.
* Miscellaneous cleanup from review: restore accidentally removed activations and a copyright header, fix a minor bug in a templated function, fix a missing `}`, revert changes that belong to another PR, resolve merge conflicts, and address lint, comment, and cosmetic issues.
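For reference, the GPT OSS gated activation is a SwiGLU whose two branches are clamped and whose gate uses a scaled sigmoid. Below is a minimal PyTorch sketch of that formulation; the half-split layout, the parameter names `limit` and `alpha`, and their defaults are assumptions based on the published GPT OSS recipe, not TE's API.

```python
import torch

def clamped_swiglu(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    """Reference clamped SwiGLU. x has shape [..., 2*d]; the split convention is assumed."""
    x_glu, x_lin = x.chunk(2, dim=-1)
    x_glu = x_glu.clamp(max=limit)              # clamp the gate branch from above
    x_lin = x_lin.clamp(min=-limit, max=limit)  # clamp the linear branch symmetrically
    return x_glu * torch.sigmoid(alpha * x_glu) * (x_lin + 1.0)
```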
Changes merged in from main while this PR was in development:

* Add cuBLASMp-backed GEMM-like API to TE common (#1824): picks up cuBLASMp during the build and adds AG-GEMM, GEMM-RS, and GEMM-AR paths with FP8 support, bias and aux outputs, a comm_sm_count option, NVTX markers, and distributed C++ tests. Replaces libcal with NCCL, removes the MPI dependency, abstracts the cuBLASMp algorithm behind a TE enum, and leaves NVTE_WITH_CUBLASMP off by default. By Vladimir Cherepanov, with Przemyslaw Tredak.
* FP8 AllGather in FP8 GroupedGEMM + fix stream-usage issue (#2086): supports current-scaling FP8 quantization with a given amax, and FP8 all-gather in the forward pass with BF16 reduce-scatter in the backward pass; the workflow is amax all-reduce -> FP8 quantize -> FP8 all-gather -> FP8 GroupedGEMM. Adds documentation and unit tests (moved to L1), adds all-layout support for Blackwell and newer, and fixes the wrong stream being used by device-to-device copies in the grouped-GEMM FFI. By Ming Huang.
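The quantize-then-gather data flow described for #2086 can be sketched with plain torch.distributed collectives. This is only an illustration of the sequence amax all-reduce -> FP8 quantize -> FP8 all-gather, under assumed names; it is not the TE implementation, and the FP8 grouped GEMM that consumes the gathered tensor is left out.

```python
import torch
import torch.distributed as dist

def fp8_allgather(x: torch.Tensor, group=None):
    """Quantize a local shard to FP8 with a globally agreed scale, then all-gather it."""
    # 1. All-reduce the local amax so every rank derives the same FP8 scale.
    amax = x.abs().amax().float()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    scale = 448.0 / torch.clamp(amax, min=1e-12)   # 448 = max normal value of float8_e4m3fn

    # 2. Quantize locally with the shared scale (current scaling with a given amax).
    x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)

    # 3. All-gather the FP8 shards; gathering one byte per element is cheaper than BF16.
    world = dist.get_world_size(group)
    out = torch.empty(world * x_fp8.shape[0], *x_fp8.shape[1:],
                      dtype=torch.uint8, device=x.device)
    dist.all_gather_into_tensor(out, x_fp8.view(torch.uint8), group=group)
    return out.view(torch.float8_e4m3fn), scale    # pass 1/scale as scale_inv to the GEMM
```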
* [JAX] Delay MeshResource validation until first usage (#2124). By Jeremy Berchtold.
* [JAX] Decouple Recipe and ScalingMode (#1728): exposes the global QuantizeConfig instance through a getter and renames UsageType to TensorSource. By Jeremy Berchtold.
* [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128). By Phuong Nguyen.
* [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118): makes sure amax is initialized to zero and fixes the sharding rule. By Phuong Nguyen.
* Further relax constraints to cuDNN 9.13 for disabling fused attention for KV caching (#2121). By Kshitij Lakhani.
* Temporarily remove comm_gemm tests (#2133). By Vladimir Cherepanov.
* [PyTorch] Disable determinism for sm100 (#2130): disables determinism for sm100+ with cuDNN < 9.14 and removes sm100 from the determinism table. By Charlene Yang.
* [PyTorch] ONNX export of FP8 current scaling (#2068): computes amax in the normalization forward pass for current scaling in untuned kernels. By Jan Bielak and Pawel Gadzinski.
* [PyTorch][MoE] Tentative fix for experts receiving zero tokens: use torch.empty for an empty shape instead of from_blob (#2134). By zhongboz.
* build: pull cached wheels (#2127). By oliver könig.
* feat: Add support for multiple quantization modes in the UB communicators (#2043).
* [Common] Add checks to CUDA kernel launches and CUDA API calls (#2074): also removes exceptions from destructors and fixes an odd dispatch in LayerNorm/RMSNorm. By Xin Yao and Tim Moon.
* [PyTorch] Support bf16+fp8 CUDA graphs (#2098). By Robin Zhang.
* Dropout with 8-bit RNG (#2014): adds a dropout kernel whose mask is drawn from an 8-bit RNG, interprets masks as bytes rather than 16-bit uints, does not require the dropout probability to be exactly representable in 8 bits, and fixes a small statistical bias caused by using less-equal instead of less-than. By Tim Moon and Vasudevan Rengasamy.
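A minimal sketch of the idea behind the 8-bit-RNG dropout: draw one random byte per element and drop the element when the byte falls strictly below a threshold derived from the drop probability (the strict less-than is what avoids the small statistical bias mentioned above). The function name is hypothetical, and quantizing `p` to 1/256 steps is a simplification of this sketch rather than a property of the actual kernel.

```python
import torch

def dropout_8bit(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Inverted dropout whose keep/drop decision uses one random byte per element."""
    if not training or p == 0.0:
        return x
    threshold = int(round(p * 256.0))     # drop probability quantized to 1/256 steps (sketch only)
    r = torch.randint(0, 256, x.shape, device=x.device, dtype=torch.uint8)
    keep = r >= threshold                 # drop iff r < threshold: strict less-than, no bias
    return x * keep.to(x.dtype) / (1.0 - p)
```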
* Create GPU reload buffers on the main stream (#2131). By Selvaraj Anandaraj.
* Fix CI failures for UB overlap changes (#2149). By djns99.
* [JAX] Fix failing fused-attention tests for dropout=0.1 and bias for sm100 (#2135): also adds a way to get all devices in the process for JAX and represents the attention bias with an enum instead of a string. By Kshitij Lakhani.
* [PyTorch][CUDA Graph] Fix FP8 weight-quantization cache under CUDA graphs (#2119): adds a no-op to the amax computation and fixes the FP8 blockwise recipe. By zhongboz.
* [PyTorch] Fix cross-entropy vanishing gradients (#2139): only the backward pass was broken; the backward kernel is generalized to handle reduced and unreduced loss. By Casper and Tim Moon.
* Fix bug when enabling --overlap-grad-reduce in MCore (#2142). By Hongbin Liu.
* Fix CUDA version in setup.py (#2132): also re-enables building the comm-gemm tests and adds a workaround for the nvidia-nvshmem package. By Vladimir Cherepanov.
* [JAX] NoScaleTensor wrapper for non-quantized data (#2136): renames HighPrecisionTensor to NoScaleTensor, uses the tensor's stored amax for current scaling when available instead of recomputing it, fixes a Shardy issue where amax had shape (1, 1, 1) instead of (1,), casts non-quantized kernels to the input dtype in VJPs, and adds higher-precision VJP tests to test_distributed_layernorm_mlp. By Jeremy Berchtold.
* [JAX] Fix GroupedScaledTensor creation with a keyword argument (#2154). By Phuong Nguyen.
* Fix a few issues with multi-process launching (#2155). By Ming Huang.
* Update list of authorized CI users (#2152). By Tim Moon.
* Fused RoPE with combined QKV input (#2122): adds a fused rotary-embedding kernel that takes the combined QKV tensor, supports rotary percent and margin, CP2, start_positions, and interleaved mode, and applies a shared-memory optimization to the separate fused RoPE kernels as well. By Vasudevan Rengasamy and Xin Yao.
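As an unfused reference for what a combined-QKV RoPE kernel computes, the sketch below applies rotary embeddings to the Q and K slices of a packed QKV tensor and leaves V untouched. The [s, b, 3, h, d] layout, the freqs shape, and the function name are assumptions for illustration, not the kernel's actual signature.

```python
import torch

def rope_on_packed_qkv(qkv: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """qkv: [s, b, 3, h, d] packed query/key/value; freqs: [s, 1, 1, d] rotary angles."""
    cos, sin = freqs.cos(), freqs.sin()

    def rotate_half(t: torch.Tensor) -> torch.Tensor:
        t1, t2 = t.chunk(2, dim=-1)
        return torch.cat((-t2, t1), dim=-1)

    def apply_rope(t: torch.Tensor) -> torch.Tensor:   # t: [s, b, h, d]
        return t * cos + rotate_half(t) * sin

    q, k, v = qkv.unbind(dim=2)
    return torch.stack((apply_rope(q), apply_rope(k), v), dim=2)
```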
protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: tongliu <[email protected]> Co-authored-by: tongliu <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [JAX] Scale swizzling via JAX transpose op (#2163) * add swizzle in jax Signed-off-by: Phuong Nguyen <[email protected]> * added outer_impl Signed-off-by: Phuong Nguyen <[email protected]> * clean up FFI Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]> Extract cpp distributed tests into a separate project (#2165) * Extract cpp distributed tests into a separate project Signed-off-by: Vladimir Cherepanov <[email protected]> * Remove obsolete exclusion Signed-off-by: Vladimir Cherepanov <[email protected]> * Run L1_cpp_distributed tests if at least 4 GPUs Signed-off-by: Vladimir Cherepanov <[email protected]> --------- Signed-off-by: Vladimir Cherepanov <[email protected]> Adds context parallelism utilities: moving cp shards to diff ranks and pad sequence to divisibility factory (#2129) * test - adds unit test for cp utilities and the utilites Signed-off-by: Jonathan Mitchell <[email protected]> * assert line change Signed-off-by: Jonathan Mitchell <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jonathan Mitchell <[email protected]> Co-authored-by: Jonathan Mitchell <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sudhakar Singh <[email protected]> * address review comments Signed-off-by: Varun Thumbe <[email protected]> * cleanup Signed-off-by: Varun Thumbe <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix linting error Signed-off-by: Varun Thumbe <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci [PyTorch Debug] Fix issue with negative underflow% stat. 
(#2107) * fix underflows log issue Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Address review comments, fix mxfp8 kernel bug: was not passing clamped swiglu parameter correctly Signed-off-by: Varun Thumbe <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe <[email protected]> Lower precision gated-act to accelerate FP8 current-scaling. (#2153) * Applying the original precision as N…
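One of the items above, the context-parallelism padding utility (#2129), boils down to padding the sequence dimension up to a multiple of a divisibility factor before splitting it into CP shards. A minimal sketch of that idea in plain PyTorch (the helper name and signature are illustrative, not the API added by the PR):

```python
import torch

def pad_to_divisibility_factor(seq: torch.Tensor, factor: int, dim: int = 1, value: float = 0.0):
    """Pad `seq` along `dim` so its length becomes a multiple of `factor`.

    Hypothetical helper illustrating the utility described above; the real
    Transformer Engine utilities live in the PR referenced in the commit message.
    """
    length = seq.size(dim)
    pad_len = (-length) % factor  # 0 when the length is already divisible
    if pad_len == 0:
        return seq, 0
    pad_shape = list(seq.shape)
    pad_shape[dim] = pad_len
    padding = seq.new_full(pad_shape, value)
    return torch.cat([seq, padding], dim=dim), pad_len

# e.g. pad seqlen 10 up to a multiple of 2 * cp_size = 4 -> seqlen 12
x = torch.randn(2, 10, 8)
x_padded, pad_len = pad_to_divisibility_factor(x, factor=4, dim=1)
assert x_padded.size(1) == 12 and pad_len == 2
```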
1 parent 25252e9 commit 7fa0f55

14 files changed (+459, -122 lines)


tests/pytorch/test_fusible_ops.py

Lines changed: 74 additions & 0 deletions
@@ -1736,6 +1736,80 @@ def test_swiglu(
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
 
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    @pytest.mark.parametrize("quantize_forward", (False, True))
+    @pytest.mark.parametrize("quantize_backward", (False, True))
+    def test_clamped_swiglu(
+        self,
+        *,
+        out_shape: Iterable[int] = (32, 32),
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantize_forward: bool,
+        quantize_backward: bool,
+        limit: float = 0.75,
+        alpha: float = 1.702,
+    ):
+        # Test SwiGLU variant used in GPT OSS.
+        # Tensor dimensions
+        in_shape = list(out_shape)
+        in_shape[-1] *= 2
+
+        # Skip invalid configurations
+        quantized_compute = quantization is not None
+        if not quantized_compute and (quantize_forward or quantize_backward):
+            pytest.skip("Quantization scheme has not been provided")
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            out_shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        x_glu, x_linear = x_ref.chunk(2, dim=-1)
+        x_glu = x_glu.clamp(min=None, max=limit)
+        x_linear = x_linear.clamp(min=-limit, max=limit)
+        out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+        y_ref = out_glu * (x_linear + 1)
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operation
+        recipe = make_recipe(quantization)
+
+        forward = te_ops.Sequential(
+            te_ops.Quantize(forward=False, backward=quantize_backward),
+            te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
+            te_ops.Quantize(forward=quantize_forward, backward=False),
+        )
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = forward(x_test)
+
+        y_test.backward(dy_test)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if quantized_compute and quantization == "nvfp4":
+            tols = dtype_tols(tex.DType.kFloat4E2M1)
+        elif quantized_compute:
+            tols = dtype_tols(tex.DType.kFloat8E4M3)
+
+        # Check results
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+
     @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
     @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
     @pytest.mark.parametrize("dtype", _dtypes)

transformer_engine/common/activation/activation_template.h

Lines changed: 4 additions & 6 deletions
@@ -51,22 +51,20 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
 }
 
 template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
-void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
+void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
   using namespace detail;
   constexpr bool IS_DGATED = false;
   constexpr NVTETensor grad = nullptr;
-
-  quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, stream);
+  quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
 }
 
 template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
           ComputeType (*DActOP)(ComputeType, const Param &)>
-void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
                    cudaStream_t stream) {
   using namespace detail;
   constexpr bool IS_DGATED = true;
-
-  quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, stream);
+  quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
 }
 
 }  // namespace transformer_engine

transformer_engine/common/activation/gelu.cu

Lines changed: 8 additions & 4 deletions
@@ -23,14 +23,16 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
 void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_geglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, e, stream);
 }
 
 void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dgeglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, e, stream);
 }
 
 void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
@@ -49,12 +51,14 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
 void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_qgeglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, e, stream);
 }
 
 void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dqgeglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, e, stream);
 }

transformer_engine/common/activation/relu.cu

Lines changed: 8 additions & 4 deletions
@@ -23,14 +23,16 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
 void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_reglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, e, stream);
 }
 
 void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dreglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, e, stream);
 }
 
 void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
@@ -49,12 +51,14 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
 void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_sreglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, e, stream);
 }
 
 void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dsreglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, e, stream);
 }

transformer_engine/common/activation/swiglu.cu

Lines changed: 21 additions & 2 deletions
@@ -23,12 +23,31 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
 void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_swiglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, e, stream);
 }
 
 void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dswiglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, e, stream);
+}
+
+void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
+                         cudaStream_t stream) {
+  NVTE_API_CALL(nvte_clamped_swiglu);
+  using namespace transformer_engine;
+  ClampedSwiGLUParam param = {limit, alpha};
+  gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
+}
+
+void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+                          float limit, float alpha, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_clamped_dswiglu);
+  using namespace transformer_engine;
+  ClampedSwiGLUParam param = {limit, alpha};
+  dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
+      grad, input, output, param, stream);
 }
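At the Python level these new entry points are reached through the fusible-ops API exercised by the new test. A minimal usage sketch, assuming the same `te_ops` module alias and the default `limit`/`alpha` values used in the test above:

```python
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# ClampedSwiGLU halves the trailing dimension: [N, 2H] -> [N, H]
x = torch.randn(32, 64, device="cuda", requires_grad=True)
forward = te_ops.Sequential(te_ops.ClampedSwiGLU(limit=0.75, alpha=1.702))

# Runs in high precision here; wrap in te.fp8_autocast(enabled=True, fp8_recipe=...)
# to exercise the quantized path, as the test does.
y = forward(x)
y.backward(torch.ones_like(y))
```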

transformer_engine/common/include/transformer_engine/activation.h

Lines changed: 40 additions & 0 deletions
@@ -173,6 +173,26 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
  */
 void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
+/*! \brief Computes the gated Swish activation of the input used in GPT OSS.
+ *
+ *  See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
+ *  This gated activation has two differences compared to the original SwiGLU:
+ *  1. Both gate and pre-activation are clipped based on the parameter limit.
+ *  2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) used in the Swish
+ *     activation, inspired by the original GELU paper https://arxiv.org/pdf/1606.08415
+ *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
+ *  the block quantization (MXFP8) of the specified shape of the block will be used.
+ *
+ *  \param[in]     input    Input tensor of shape [N, H * 2].
+ *  \param[in,out] output   Output tensor of shape [N, H].
+ *                          It computes Act(input[N, :H]) x input[N, H:]
+ *  \param[in]     limit    Clipping limit for gate and pre-activation.
+ *  \param[in]     alpha    Scaling factor for the sigmoid function used in the activation.
+ *  \param[in]     stream   CUDA stream used for the operation.
+ */
+void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
+                         cudaStream_t stream);
+
 /*! \brief Computes the gated ReLU activation of the input.
  *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
  *  the block quantization (MXFP8) of the specified shape of the block will be used.
@@ -230,6 +250,26 @@ void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
 void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream);
 
+/*! \brief Computes the gradient of the gated Swish activation of the input used in GPT OSS.
+ *
+ *  https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
+ *  This activation has two differences compared to the original SwiGLU:
+ *  1. Both gate and pre-activation are clipped based on the parameter limit.
+ *  2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) used in the Swish
+ *     activation, inspired by the original GELU paper https://arxiv.org/pdf/1606.08415
+ *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
+ *  the block quantization (MXFP8) of the specified shape of the block will be used.
+ *
+ *  \param[in]     grad     Incoming gradient of shape [N, H].
+ *  \param[in]     input    Forward input tensor of shape [N, H * 2].
+ *  \param[in,out] output   Outgoing gradient of shape [N, H * 2].
+ *  \param[in]     limit    Clipping limit for gate and pre-activation.
+ *  \param[in]     alpha    Scaling factor for the sigmoid function used in the activation.
+ *  \param[in]     stream   CUDA stream used for the operation.
+ */
+void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+                          float limit, float alpha, cudaStream_t stream);
+
 /*! \brief Computes the gated ReLU activation gradient.
  *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
  *  the block quantization (MXFP8) of the specified shape of the block will be used.
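Written out, the forward computation these header comments describe, and the gradients the backward entry point produces, are as follows. This matches the plain-PyTorch reference in the new test and assumes the hard clip contributes zero gradient outside its limits:

$$
g = \min(x_{\mathrm{glu}},\ \mathrm{limit}), \qquad
l = \operatorname{clip}(x_{\mathrm{lin}},\ -\mathrm{limit},\ \mathrm{limit}), \qquad
y = g\,\sigma(\alpha g)\,(l + 1)
$$

$$
\frac{\partial y}{\partial x_{\mathrm{glu}}} =
\begin{cases}
\sigma(\alpha g)\bigl(1 + \alpha g\,(1 - \sigma(\alpha g))\bigr)(l + 1), & x_{\mathrm{glu}} < \mathrm{limit} \\
0, & \text{otherwise}
\end{cases}
\qquad
\frac{\partial y}{\partial x_{\mathrm{lin}}} =
\begin{cases}
g\,\sigma(\alpha g), & \lvert x_{\mathrm{lin}} \rvert < \mathrm{limit} \\
0, & \text{otherwise}
\end{cases}
$$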
