Skip to content

Commit 0230cd0

Browse files
zyongye, heheda12345, youkaichao, LucasWilkinson, robertgshaw2-redhat
committed
[New Model] DeepSeek-V3.2 (Rebased to Main) (#25896)
Signed-off-by: Chen Zhang <[email protected]> Signed-off-by: youkaichao <[email protected]> Signed-off-by: Lucas Wilkinson <[email protected]> Signed-off-by: mgoin <[email protected]> Signed-off-by: NickLucche <[email protected]> Signed-off-by: Yongye Zhu <[email protected]> Signed-off-by: Barry Kang <[email protected]> Signed-off-by: Lucia Fang <[email protected]> Co-authored-by: Chen Zhang <[email protected]> Co-authored-by: youkaichao <[email protected]> Co-authored-by: Lucas Wilkinson <[email protected]> Co-authored-by: Robert Shaw <[email protected]> Co-authored-by: Lucas Wilkinson <[email protected]> Co-authored-by: yewentao256 <[email protected]> Co-authored-by: Wentao Ye <[email protected]> Co-authored-by: mgoin <[email protected]> Co-authored-by: Lucia Fang <[email protected]> Co-authored-by: Lucia Fang <[email protected]> Co-authored-by: NickLucche <[email protected]> Co-authored-by: Siyuan Fu <[email protected]> Co-authored-by: Matthew Bonanni <[email protected]> Co-authored-by: Xiaozhu Meng <[email protected]> Co-authored-by: Barry Kang <[email protected]> Signed-off-by: yewentao256 <[email protected]>
1 parent da71651 commit 0230cd0

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

71 files changed

+3918
-221
lines changed

cmake/external_projects/flashmla.cmake

Lines changed: 76 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
1818
else()
1919
FetchContent_Declare(
2020
flashmla
21-
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
22-
GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
21+
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
22+
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
2323
GIT_PROGRESS TRUE
2424
CONFIGURE_COMMAND ""
2525
BUILD_COMMAND ""
@@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
3333
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
3434
# Only build FlashMLA kernels if we are building for something compatible with
3535
# sm90a
36-
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
37-
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
36+
37+
set(SUPPORT_ARCHS)
38+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
39+
list(APPEND SUPPORT_ARCHS 9.0a)
40+
endif()
41+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
42+
list(APPEND SUPPORT_ARCHS 10.0a)
43+
endif()
44+
45+
46+
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
47+
if(FLASH_MLA_ARCHS)
48+
set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
49+
list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
50+
3851
set(FlashMLA_SOURCES
39-
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
40-
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
41-
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
42-
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
43-
${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
52+
${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
53+
${flashmla_SOURCE_DIR}/csrc/pybind.cpp
54+
${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
55+
${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
56+
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
57+
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
58+
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
59+
${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
60+
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
61+
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
62+
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
63+
)
64+
65+
set(FlashMLA_Extension_SOURCES
66+
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
67+
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
68+
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
69+
)
4470

4571
set(FlashMLA_INCLUDES
72+
${flashmla_SOURCE_DIR}/csrc
73+
${flashmla_SOURCE_DIR}/csrc/sm90
74+
${flashmla_SOURCE_DIR}/csrc/cutlass/include
75+
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
76+
)
77+
78+
set(FlashMLA_Extension_INCLUDES
79+
${flashmla_SOURCE_DIR}/csrc
80+
${flashmla_SOURCE_DIR}/csrc/sm90
81+
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
4682
${flashmla_SOURCE_DIR}/csrc/cutlass/include
47-
${flashmla_SOURCE_DIR}/csrc)
83+
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
84+
)
4885

4986
set_gencode_flags_for_srcs(
5087
SRCS "${FlashMLA_SOURCES}"
5188
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
5289

90+
set_gencode_flags_for_srcs(
91+
SRCS "${FlashMLA_Extension_SOURCES}"
92+
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
93+
5394
define_gpu_extension_target(
5495
_flashmla_C
5596
DESTINATION vllm
@@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
60101
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
61102
USE_SABI 3
62103
WITH_SOABI)
104+
105+
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
106+
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
107+
target_compile_options(_flashmla_C PRIVATE
108+
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
109+
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
110+
111+
define_gpu_extension_target(
112+
_flashmla_extension_C
113+
DESTINATION vllm
114+
LANGUAGE ${VLLM_GPU_LANG}
115+
SOURCES ${FlashMLA_Extension_SOURCES}
116+
COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
117+
ARCHITECTURES ${VLLM_GPU_ARCHES}
118+
INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
119+
USE_SABI 3
120+
WITH_SOABI)
121+
122+
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
123+
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
124+
target_compile_options(_flashmla_extension_C PRIVATE
125+
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
126+
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
63127
else()
64-
# Create an empty target for setup.py when not targeting sm90a systems
128+
# Create empty targets for setup.py when not targeting sm90a systems
65129
add_custom_target(_flashmla_C)
130+
add_custom_target(_flashmla_extension_C)
66131
endif()
67132

csrc/cache.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,11 @@ void cp_gather_cache(
5656
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
5757
torch::Tensor const& cu_seq_lens, // [BATCH+1]
5858
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
59+
60+
// Indexer K quantization and cache function
61+
void indexer_k_quant_and_cache(
62+
torch::Tensor& k, // [num_tokens, head_dim]
63+
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
64+
torch::Tensor& slot_mapping, // [num_tokens]
65+
int64_t quant_block_size, // quantization block size
66+
const std::string& scale_fmt);

0 commit comments

Comments
 (0)