Skip to content

Commit da33d26

Browse files
Merge branch 'main' into jm/context_parallel_utilities
2 parents 3aaba86 + 5f2b831 commit da33d26

File tree

125 files changed

+4794
-941
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

125 files changed

+4794
-941
lines changed

.github/workflows/trigger-ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ jobs:
5757
|| github.actor == 'tdophung'
5858
|| github.actor == 'vthumbe1503'
5959
|| github.actor == 'janekb04'
60+
|| github.actor == 'shengfangd'
6061
)
6162
steps:
6263
- name: Check if comment is issued by authorized person

docs/api/pytorch.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ pyTorch
4949

5050
.. autoapifunction:: transformer_engine.pytorch.moe_permute
5151

52-
.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
52+
.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
5353

5454
.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
5555

@@ -62,3 +62,6 @@ pyTorch
6262
.. autoapifunction:: transformer_engine.pytorch.initialize_ub
6363

6464
.. autoapifunction:: transformer_engine.pytorch.destroy_ub
65+
66+
.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
67+
:members: FP8, NONE

docs/examples/onnx/onnx_export.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"\n",
1111
"<b>Note:</b>\n",
1212
"\n",
13-
"Currently, export to ONNX is supported only for high precision, FP8 delayed scaling and MXFP8.\n",
13+
"Currently, export to ONNX is supported only for high precision, FP8 delayed scaling, FP8 current scaling and MXFP8.\n",
1414
"\n",
1515
"</div>\n",
1616
"\n",

examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,13 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)
263263
te.module.base.initialize_ub(
264264
[batched_size, hidden_size],
265265
tp_size,
266-
use_fp8=opts.fp8,
266+
quantization_modes=[
267+
(
268+
te.module.base.UserBufferQuantizationMode.FP8
269+
if opts.fp8
270+
else te.module.base.UserBufferQuantizationMode.NONE
271+
)
272+
],
267273
dtype=torch.bfloat16,
268274
bootstrap_backend=opts.bootstrap_backend,
269275
)

qa/L0_cppunittest/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp
1717
cmake -GNinja -Bbuild .
1818
cmake --build build
1919
export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS))
20-
ctest --test-dir build -j$NUM_PARALLEL_JOBS
20+
ctest --test-dir build -j$NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)'

qa/L1_cpp_distributed/test.sh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
set -e
6+
7+
# Find TE
8+
: ${TE_PATH:=/opt/transformerengine}
9+
TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
10+
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
11+
12+
cd $TE_PATH/tests/cpp
13+
cmake -GNinja -S. -Bbuild
14+
cmake --build build
15+
mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm

qa/L1_jax_distributed_unittest/test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ set -xe
99
mkdir -p "$XML_LOG_DIR"
1010

1111
NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
12+
SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh

setup.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
"""Installation script."""
66

7+
from importlib import metadata
78
import os
89
import time
910
from pathlib import Path
@@ -16,6 +17,7 @@
1617
from build_tools.te_version import te_version
1718
from build_tools.utils import (
1819
cuda_archs,
20+
cuda_version,
1921
get_frameworks,
2022
remove_dups,
2123
)
@@ -66,6 +68,18 @@ def setup_common_extension() -> CMakeExtension:
6668
if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
6769
cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
6870

71+
if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
72+
cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
73+
cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
74+
f"nvidia-cublasmp-cu{cuda_version()[0]}"
75+
).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
76+
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
77+
nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
78+
f"nvidia-nvshmem-cu{cuda_version()[0]}"
79+
).locate_file("nvidia/nvshmem")
80+
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
81+
print("CMAKE_FLAGS:", cmake_flags[-2:])
82+
6983
# Add custom CMake arguments from environment variable
7084
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
7185
if nvte_cmake_extra_args:

tests/cpp/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_
3737
message(STATUS "Found transformer_engine library: ${TE_LIB}")
3838
include_directories(../../transformer_engine/common/include)
3939
include_directories(../../transformer_engine/common)
40+
include_directories(../../transformer_engine)
4041
include_directories(${CMAKE_SOURCE_DIR})
4142

4243
find_package(CUDAToolkit REQUIRED)
4344
include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
4445

46+
add_subdirectory(comm_gemm)
4547
add_subdirectory(operator)
4648
add_subdirectory(util)

tests/cpp/comm_gemm/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
add_executable(test_comm_gemm
6+
test_comm_gemm.cu
7+
../test_common.cu)
8+
9+
find_package(OpenMP REQUIRED)
10+
find_package(MPI REQUIRED)
11+
find_library(NCCL_LIB
12+
NAMES nccl libnccl
13+
PATH_SUFFIXES lib
14+
REQUIRED)
15+
target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include)
16+
target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX)
17+
18+
include(GoogleTest)
19+
gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600)

0 commit comments

Comments
 (0)