Commit ba37529
FP8 Output Quantization for GEMM (#2123)
Adds FP8 output quantization support to the GEMM extension. When the requested output quantizer cannot be produced directly by cuBLAS (MXFP8 or FP8 current-scaling outputs, or high-precision inputs with an FP8 output), the GEMM now runs in high precision and the result is quantized afterwards (unfused quantization); a delayed-scaling output quantizer with per-tensor-scaled FP8 inputs keeps the fused cuBLAS path.

Squashed history on this branch:

* Initial working test for FP8 output quantization.
* Restrict the number of cases for unfused quantization; some FP8->FP8 cases are handled by cuBLAS.
* Revert an accidental change; fix merge conflicts; fix a missing `}` in the code.
* MXFP8 unfused quantization support, refined unit test, removed unnecessary quantization code (plus a missed removal in a follow-up).
* Minor code cleanup, cosmetics, comment updates, and a linting fix.
* Minor bug fix: the quantizer should not be None for unfused quantization.
* Bug fix: delayed-scaling quantization with MXFP8 inputs did not work.
* Revert changes accidentally added while merging (test_multi_process_distributed_grouped_gemm.py, dense.py).
* Address review comments: quantization inside the GEMM and outside of it should match exactly for FP32 accumulation; fix the unit test error; trigger CI.
* Assorted pre-commit.ci auto-fixes.

Upstream commits merged in from main while this PR was in flight:

* Add cuBLASMp-backed GEMM-like API to TE common (#1824) - Vladimir Cherepanov
* FP8 AllGather in FP8 GroupedGEMM + fix stream usage issue (#2086) - Ming Huang
* [JAX] Delay MeshResource validation until first usage (#2124) - Jeremy Berchtold
* [JAX] Decouple Recipe and ScalingMode (#1728) - Jeremy Berchtold
* [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) - Phuong Nguyen
* [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) - Phuong Nguyen
* Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) - Kshitij Lakhani
* Temporarily remove comm_gemm tests (#2133) - Vladimir Cherepanov
* [PyTorch] Disable determinism for sm100 (#2130) - Charlene Yang
* [PyTorch] ONNX export of FP8 Current Scaling (#2068) - Jan Bielak, Pawel Gadzinski
* [PyTorch][MOE] Tentative fix for replacing from_blob with empty for experts receiving zero tokens (#2134) - zhongboz
* build: pull cached wheels (#2127) - oliver könig
* feat: Add support for multiple quantization modes in the UB communicators (#2043)
* [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) - Xin Yao, Tim Moon
* [PyTorch] Support bf16+fp8 cudagraph (#2098) - Robin Zhang
* Dropout with 8-bit RNG (#2014) - Tim Moon, Vasudevan Rengasamy
* Create GPU reload buffers on main stream (#2131) - Selvaraj Anandaraj
* Fix CI failures for UB overlap changes (#2149) - djns99
* [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) - Kshitij Lakhani
* [PyTorch][CUDA Graph] Fix FP8 weight quantization cache under CUDA graph (#2119) - zhongboz

---------

Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: vthumbe1503 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c221909 commit ba37529
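For context, the new behavior is exercised through the PyTorch extension's general_gemm wrapper. The snippet below is a minimal sketch distilled from the new unit test in tests/pytorch/test_numerics.py; the 32x32 shape and the E4M3 dtype are arbitrary illustration choices, not part of the commit.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

# Quantize both GEMM operands to MXFP8, a recipe whose FP8 output cuBLAS cannot
# produce directly, so the new unfused output-quantization path is taken.
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
a_fp8 = quantizer(torch.randn(32, 32, device="cuda", dtype=torch.bfloat16))
b_fp8 = quantizer(torch.randn(32, 32, device="cuda", dtype=torch.bfloat16))

# Passing an output quantizer makes the extension run the GEMM in high precision
# and quantize the result to FP8 afterwards (unfused output quantization).
out_fp8, *_ = general_gemm(
    a_fp8,
    b_fp8,
    get_workspace(),
    torch.float32,                   # high-precision GEMM/accumulation dtype
    quantization_params=quantizer,   # request an FP8 (MXFP8) output
    bias=None,
    use_split_accumulator=False,
)
print(out_fp8.dequantize().shape)    # torch.Size([32, 32])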

File tree

4 files changed (+125, -64 lines changed):

  tests/pytorch/test_numerics.py
  transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
  transformer_engine/pytorch/csrc/extensions/gemm.cpp
  transformer_engine/pytorch/csrc/quantizer.cpp

tests/pytorch/test_numerics.py

Lines changed: 74 additions & 2 deletions

@@ -39,16 +39,21 @@
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
 from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
 from transformer_engine.pytorch.utils import get_device_compute_capability
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from utils import ModelConfig, reset_rng_states, get_available_attention_backends
 
+
 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
 
 sm_80plus = get_device_compute_capability() >= (8, 0)
@@ -2607,6 +2612,73 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
     torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
 
 
+@pytest.mark.parametrize("N", [32])
+@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize(
+    "input_quantizer",
+    [
+        Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"),
+        MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
+    ],
+)
+@pytest.mark.parametrize(
+    "out_quantizer",
+    [
+        Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"),
+        MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
+        Float8Quantizer(
+            torch.ones(1).cuda().squeeze(), torch.ones(1).cuda().squeeze(), tex.DType.kFloat8E4M3
+        ),
+    ],
+)
+def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_quantizer):
+    # For MXFP8 and CurrentScaling, the unfused quantization below should happen:
+    # FP8 input --> cuBLAS GEMM --> BF16 output --> quantize to FP8 --> FP8 output
+    # Skip invalid configurations
+    is_mxfp8_needed = isinstance(input_quantizer, MXFP8Quantizer) or isinstance(
+        out_quantizer, MXFP8Quantizer
+    )
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if is_mxfp8_needed and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    inp_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype))
+    weight_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype))
+    outp_type = torch.float32
+    quantized_out, *_ = general_gemm(
+        weight_fp8,
+        inp_fp8,
+        get_workspace(),
+        outp_type,
+        quantization_params=out_quantizer,
+        bias=None,
+        use_split_accumulator=False,
+    )
+
+    out, *_ = general_gemm(
+        weight_fp8,
+        inp_fp8,
+        get_workspace(),
+        outp_type,
+        quantization_params=None,
+        bias=None,
+        use_split_accumulator=False,
+    )
+    expected_quantized_out = out_quantizer(out)
+
+    # Match results against PyTorch GEMM and allow for quantization tolerance
+    pytorch_out = torch.matmul(
+        inp_fp8.dequantize().to(torch.float64),
+        torch.transpose(weight_fp8.dequantize().to(torch.float64), 0, 1),
+    )
+    fp8_tols = dict(rtol=0.125, atol=0.0675)
+    torch.testing.assert_close(
+        pytorch_out.to(outp_type), expected_quantized_out.dequantize(), **fp8_tols
+    )
+    # Match results between quantization happening inside vs outside general_gemm
+    torch.testing.assert_close(expected_quantized_out.dequantize(), quantized_out.dequantize())
+
+
 @pytest.mark.parametrize(
     "shape",
     [
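The new test can be selected on its own with `pytest tests/pytorch/test_numerics.py -k fp8gemm_with_unfused_quantization` on a machine with FP8 (and, for the MXFP8 parametrizations, MXFP8) support. For reference, the delayed-scaling out_quantizer in the parametrization above is built from a scale tensor, an amax tensor, and the FP8 dtype; a minimal sketch of that construction follows, with the argument roles inferred from the test rather than taken from documented API, so treat them as an assumption.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

scale = torch.ones(1).cuda().squeeze()   # per-tensor scale buffer (assumed role)
amax = torch.ones(1).cuda().squeeze()    # running amax buffer (assumed role)
delayed_scaling_quantizer = Float8Quantizer(scale, amax, tex.DType.kFloat8E4M3)

x_fp8 = delayed_scaling_quantizer(torch.randn(32, 32, device="cuda", dtype=torch.bfloat16))
x_hp = x_fp8.dequantize()                # back to high precision for comparison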

transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu

Lines changed: 8 additions & 3 deletions

@@ -579,14 +579,19 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
                 "Input and output_t must have the same shape for columnwise non-transpose case.");
     }
   }
-
-  NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
+  if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
+    // output may not be defined if rowwise quantization is not needed.
+    NVTE_CHECK(output.dtype == output_t.dtype,
+               "output and output_t need to have the same dtype.");
+  }
   NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
   bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
   size_t scale_t_k = scale_inv_t.shape[1];
   scale_t_stride_x = columnwise_compact ? 1 : scale_t_k;
   scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
  }
+  auto output_dtype =
+      rowwise_option != FP8BlockwiseRowwiseOption::NONE ? output.dtype : output_t.dtype;
 
   const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
   const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
@@ -597,7 +602,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
       input.dtype, InputType,
 
       TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
-          output.dtype, OutputType,
+          output_dtype, OutputType,
 
           dim3 grid(num_blocks_x, num_blocks_y, 1);
transformer_engine/pytorch/csrc/extensions/gemm.cpp

Lines changed: 41 additions & 12 deletions

@@ -93,6 +93,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                               bool use_split_accumulator, CommOverlapCore* comm_overlap,
                               std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
                               bool bulk_overlap, float alpha, std::optional<float> beta) {
+  using namespace transformer_engine::pytorch::detail;
+
   // Input tensors
   NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
   NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
@@ -123,10 +125,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
               "into D tensor. Beta has nothing to be applied to.");
   }
 
+  DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype();
   // Output tensor
   TensorWrapper D_tensor;
   if (D.is_none()) {
-    DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype();
     std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer);
   } else {
     D_tensor = makeTransformerEngineTensor(D, quantizer);
@@ -139,12 +141,35 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     }
   }
 
+  // maintain unquantized tensor in case we need unfused quantization support.
+  TensorWrapper unquantized_D_tensor;
+  py::object unquantized_out;
+  // Unfused quantization is needed in the following cases
+  // 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization needed after that)
+  // 2. Inputs: FP8, Output: FP8 (For any quantization apart from delayed scaling,
+  //    GEMM Output needs to be in BF16, to allow for unfused quantization)
+  bool unfused_quantization_needed = !quantizer.is_none();
+  if (low_precision) {
+    // At the moment, only use-case for fused GEMM:
+    // Delayed scaling quantizer with per-tensor scaling inputs
+    bool is_per_tensor_scaling_input = IsFloat8Tensor(A.ptr()) || IsFloat8Tensor(B.ptr());
+    if (IsFloat8Quantizers(quantizer.ptr()) && is_per_tensor_scaling_input)
+      unfused_quantization_needed = false;
+  }
+
+  if (unfused_quantization_needed) {
+    NoneQuantizer q{none};
+    std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype);
+  }
+  TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor;
+
   // Bias tensor
   TensorWrapper bias_tensor;
   MaybeTensor bias_grad = std::nullopt;
   if (bias.has_value()) {
     if (grad) {
-      auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
+      auto opts =
+          torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA);
       bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
       bias_tensor = makeTransformerEngineTensor(*bias_grad);
     } else {
@@ -157,7 +182,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
 
   // Activation input tensor
   MaybeTensor pre_gelu_out = std::nullopt;
-  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
+  DType gelu_type = low_precision ? bias_type : out_tensor.dtype();
   if (gelu) {
     if (!grad) {
       auto dtype = GetATenDType(gelu_type);
@@ -210,22 +235,22 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     // Direct GEMM call to the correct overlap
     if (bulk_overlap) {
       NVTE_SCOPED_GIL_RELEASE({
-        comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
+        comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor,
                                    te_pre_gelu_out, te_workspace, grad, accumulate,
                                    use_split_accumulator, comm_type.value(), extra_output_tensor,
                                    main_stream);
       });
     } else if (comm_type.value() == CommOverlapType::AG) {
       if (comm_overlap->is_atomic_gemm()) {
         NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
+          comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
                                                bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                                accumulate, use_split_accumulator,
                                                extra_output_tensor, main_stream);
         });
       } else {
         NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
+          comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
                                          bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                          accumulate, use_split_accumulator, extra_output_tensor,
                                          main_stream);
@@ -234,14 +259,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     } else {
       if (comm_overlap->is_atomic_gemm()) {
         NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
+          comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
                                                bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                                accumulate, use_split_accumulator,
                                                extra_output_tensor, main_stream);
         });
       } else {
         NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
+          comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
                                          bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                          accumulate, use_split_accumulator, extra_output_tensor,
                                          main_stream);
@@ -251,23 +276,27 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   } else {
     // Launch GEMM
     NVTE_SCOPED_GIL_RELEASE({
-      nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(),
+      nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(),
                               bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
                               te_workspace.data(), alpha, *beta, use_split_accumulator,
                               num_math_sms, main_stream);
     });
   }
  } else {
-  if (D_tensor.numel() != 0 && !accumulate) {
-    D_tensor.zero_(main_stream);
+  if (out_tensor.numel() != 0 && !accumulate) {
+    out_tensor.zero_(main_stream);
   }
   if (bias.has_value()) {
     if (bias->numel() != 0 && grad) {
       bias_grad->zero_();
     }
   }
  }
-
+  if (unfused_quantization_needed) {
+    // Quantize the output
+    std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
+    my_quantizer->quantize(unquantized_D_tensor, D_tensor);
+  }
   // Pack outputs
   std::vector<py::object> out;
   out.emplace_back(std::move(D));
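The fused-versus-unfused decision introduced above can be summarized at the Python level as follows. This is only a schematic mirror of the C++ logic for illustration; the helper name and boolean flags are hypothetical and do not exist in the codebase.

def needs_unfused_output_quantization(
    has_output_quantizer: bool,
    inputs_are_low_precision: bool,
    output_quantizer_is_delayed_scaling: bool,
    any_input_is_per_tensor_fp8: bool,
) -> bool:
    """Schematic mirror of the check in gemm.cpp (hypothetical helper)."""
    if not has_output_quantizer:
        # No output quantizer: cuBLAS writes the requested output dtype directly.
        return False
    if (
        inputs_are_low_precision
        and output_quantizer_is_delayed_scaling
        and any_input_is_per_tensor_fp8
    ):
        # Only remaining fused case: per-tensor-scaled FP8 inputs with a
        # delayed-scaling output quantizer, which cuBLAS can emit directly.
        return False
    # Everything else (MXFP8 or current-scaling outputs, BF16 inputs with an FP8
    # output, ...) runs the GEMM in high precision and quantizes the result after.
    return True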

transformer_engine/pytorch/csrc/quantizer.cpp

Lines changed: 2 additions & 47 deletions

@@ -96,16 +96,6 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
   at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
   tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                    getTensorShape(amax));
-  auto rowwise_data = tensor->get_rowwise_data();
-  rowwise_data.dtype = static_cast<NVTEDType>(dtype);
-
-  auto columnwise_data = tensor->get_columnwise_data();
-  columnwise_data.dtype = static_cast<NVTEDType>(dtype);
-
-  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                           rowwise_data.shape);
-  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                              columnwise_data.shape);
 }
 
 std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
@@ -318,17 +308,6 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso
   at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
   tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                    getTensorShape(amax));
-  // quantize output and its transpose
-  auto rowwise_data = tensor->get_rowwise_data();
-  rowwise_data.dtype = static_cast<NVTEDType>(dtype);
-
-  auto columnwise_data = tensor->get_columnwise_data();
-  columnwise_data.dtype = static_cast<NVTEDType>(dtype);
-
-  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                           rowwise_data.shape);
-  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                              columnwise_data.shape);
 }
 
 std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
@@ -562,20 +541,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
   this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
 }
 
-void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
-  // Change the rowwise and columnwise_data to the configured dtype.
-  // May be a switch between E5M2 and E4M3.
-  auto rowwise_data = tensor->get_rowwise_data();
-  rowwise_data.dtype = static_cast<NVTEDType>(dtype);
-
-  auto columnwise_data = tensor->get_columnwise_data();
-  columnwise_data.dtype = static_cast<NVTEDType>(dtype);
-
-  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                           rowwise_data.shape);
-  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                              columnwise_data.shape);
-}
+void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {}
 
 std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     const std::vector<size_t>& shape, DType dtype) const {
@@ -917,18 +883,7 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize
   this->dtype = quantizer.attr("dtype").cast<DType>();
 }
 
-void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
-  auto rowwise_data = tensor->get_rowwise_data();
-  rowwise_data.dtype = static_cast<NVTEDType>(dtype);
-
-  auto columnwise_data = tensor->get_columnwise_data();
-  columnwise_data.dtype = static_cast<NVTEDType>(dtype);
-
-  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                           rowwise_data.shape);
-  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                              columnwise_data.shape);
-}
+void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {}
 
 std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::vector<size_t>& shape,
                                                                    DType dtype) const {
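The net effect of the deletions in this file is that set_quantization_params no longer rewrites the rowwise and columnwise data dtypes; for Float8BlockQuantizer and MXFP8Quantizer it becomes a no-op, since the FP8 dtype is already fixed when the quantizer creates the tensor. A schematic before/after mirror in Python follows; the class and function names are hypothetical and exist only for illustration.

from dataclasses import dataclass

@dataclass
class FakeQuantizedData:  # hypothetical stand-in for the tensor's data buffers
    rowwise_dtype: str
    columnwise_dtype: str

def set_quantization_params_old(data: FakeQuantizedData, quantizer_dtype: str) -> None:
    # Old behavior: force both data buffers to the quantizer's configured FP8 dtype.
    data.rowwise_dtype = quantizer_dtype
    data.columnwise_dtype = quantizer_dtype

def set_quantization_params_new(data: FakeQuantizedData, quantizer_dtype: str) -> None:
    # New behavior: nothing to do; create_tensor already allocated the buffers
    # with the quantizer's dtype, so no dtype rewrite is needed here.
    pass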
