@@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
18
18
else ()
19
19
FetchContent_Declare(
20
20
flashmla
21
- GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
22
- GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
21
+ GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
22
+ GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
23
23
GIT_PROGRESS TRUE
24
24
CONFIGURE_COMMAND ""
25
25
BUILD_COMMAND ""
@@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
33
33
# The FlashMLA kernels only work on sm90a (Hopper, CUDA 12.3 or later) and sm100a (CUDA 12.8 or later).
34
34
# Only build FlashMLA kernels if we are building for something compatible with
35
35
# sm90a or sm100a
36
- cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS} " )
37
- if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
36
+
37
# Collect the architectures FlashMLA can be built for with the detected CUDA
# toolkit: 9.0a needs CUDA 12.3 or later, 10.0a needs CUDA 12.8 or later.
# The compiler version is quoted so that an empty value produces a false
# comparison instead of a malformed if() expression.
set(SUPPORT_ARCHS)
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL 12.3)
  list(APPEND SUPPORT_ARCHS 9.0a)
endif()
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL 12.8)
  list(APPEND SUPPORT_ARCHS 10.0a)
endif()

# Combine the supported set with the requested CUDA_ARCHS; the helper is
# defined elsewhere and stores its result in FLASH_MLA_ARCHS. Both lists are
# passed as single quoted arguments with no padding, so no entry picks up a
# stray trailing space.
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
47
+ if (FLASH_MLA_ARCHS)
48
# Compile flags for FlashMLA: the common VLLM_GPU_FLAGS followed by the extra
# options listed here. Used as COMPILE_FLAGS for _flashmla_extension_C.
set(VLLM_FLASHMLA_GPU_FLAGS
  ${VLLM_GPU_FLAGS}
  "--expt-relaxed-constexpr"
  "--expt-extended-lambda"
  "--use_fast_math")
50
+
38
51
# Source list for the core FlashMLA build: the torch/pybind entry points, the
# architecture-independent (smxx) kernels, and the sm90 and sm100 kernels.
# Each entry is a single path rooted at the fetched FlashMLA checkout.
set(FlashMLA_SOURCES
  ${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
  ${flashmla_SOURCE_DIR}/csrc/pybind.cpp
  ${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
  ${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
  ${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
  ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
  ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
  ${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
  ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
  ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
  ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
)
64
+
65
# Sources of the separate _flashmla_extension_C module: its own torch/pybind
# entry points and the sm90 dense FP8 forward kernel.
# Each entry is a single path rooted at the fetched FlashMLA checkout.
set(FlashMLA_Extension_SOURCES
  ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
  ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
  ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
)
44
70
45
71
# Include directories for the core FlashMLA sources: the csrc root, the sm90
# subtree, and the CUTLASS headers vendored inside the FlashMLA checkout.
set(FlashMLA_INCLUDES
  ${flashmla_SOURCE_DIR}/csrc
  ${flashmla_SOURCE_DIR}/csrc/sm90
  ${flashmla_SOURCE_DIR}/csrc/cutlass/include
  ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
)
77
+
78
# Include directories for _flashmla_extension_C: the same roots as the core
# build plus the dense FP8 extension directory.
set(FlashMLA_Extension_INCLUDES
  ${flashmla_SOURCE_DIR}/csrc
  ${flashmla_SOURCE_DIR}/csrc/sm90
  ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
  ${flashmla_SOURCE_DIR}/csrc/cutlass/include
  ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
)
48
85
49
86
# Apply the selected FlashMLA architectures to both source lists through the
# project helper (defined elsewhere). Each list is passed as one quoted
# argument with no padding, so the final entry carries no trailing space.
set_gencode_flags_for_srcs(
  SRCS "${FlashMLA_SOURCES}"
  CUDA_ARCHS "${FLASH_MLA_ARCHS}")

set_gencode_flags_for_srcs(
  SRCS "${FlashMLA_Extension_SOURCES}"
  CUDA_ARCHS "${FLASH_MLA_ARCHS}")
93
+
53
94
define_gpu_extension_target(
54
95
_flashmla_C
55
96
DESTINATION vllm
@@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
60
101
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
61
102
USE_SABI 3
62
103
WITH_SOABI)
104
+
105
# The module keeps its Stable ABI setting (USE_SABI on the target), but
# Py_LIMITED_API is undefined while compiling CUDA and C++ sources so the
# macro does not affect the nvcc and C++ compiles.
target_compile_options(_flashmla_C PRIVATE
  "$<$<OR:$<COMPILE_LANGUAGE:CUDA>,$<COMPILE_LANGUAGE:CXX>>:-UPy_LIMITED_API>")
110
+
111
# Extension module built from FlashMLA_Extension_SOURCES with the FlashMLA
# compile flags and the extension include directories; installed under vllm.
define_gpu_extension_target(
  _flashmla_extension_C
  DESTINATION vllm
  LANGUAGE ${VLLM_GPU_LANG}
  SOURCES ${FlashMLA_Extension_SOURCES}
  COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
  ARCHITECTURES ${VLLM_GPU_ARCHES}
  INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
  USE_SABI 3
  WITH_SOABI)
121
+
122
# The module keeps its Stable ABI setting (USE_SABI on the target), but
# Py_LIMITED_API is undefined while compiling CUDA and C++ sources so the
# macro does not affect the nvcc and C++ compiles.
target_compile_options(_flashmla_extension_C PRIVATE
  "$<$<OR:$<COMPILE_LANGUAGE:CUDA>,$<COMPILE_LANGUAGE:CXX>>:-UPy_LIMITED_API>")
63
127
else ()
64
- # Create an empty target for setup.py when not targeting sm90a systems
128
# Create empty placeholder targets for setup.py when none of the supported
# architectures (sm90a / sm100a) is being targeted.
foreach(flashmla_stub _flashmla_C _flashmla_extension_C)
  add_custom_target(${flashmla_stub})
endforeach()
66
131
endif ()
67
132
0 commit comments