Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
7db8212
marlin optimization
jinzhen-lin Apr 18, 2025
48874db
fix
jinzhen-lin Apr 18, 2025
e13c8a1
fix
jinzhen-lin Apr 18, 2025
dd8a5e1
fix
jinzhen-lin Apr 18, 2025
fd16948
fix
jinzhen-lin Apr 18, 2025
ac5dc47
fix
jinzhen-lin Apr 18, 2025
cb8229c
fix
jinzhen-lin Apr 18, 2025
8bac124
fix moe performance bad cases
jinzhen-lin Apr 19, 2025
649701b
fix dense marlin performance bad cases
jinzhen-lin Apr 19, 2025
5fa7f33
some fix
jinzhen-lin Apr 19, 2025
eb3f2ed
fix
jinzhen-lin Apr 20, 2025
15daa36
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin Apr 20, 2025
fb85636
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin Apr 22, 2025
72a6ded
fix
jinzhen-lin Apr 23, 2025
367b5d9
remove kU8
jinzhen-lin Apr 23, 2025
90e1063
fix name
jinzhen-lin Apr 23, 2025
f42ac97
fix and add comment
jinzhen-lin Apr 23, 2025
110bbb8
fix
jinzhen-lin Apr 23, 2025
30dbb98
fix
jinzhen-lin Apr 23, 2025
720b900
fix
jinzhen-lin Apr 23, 2025
02d33ed
fix
jinzhen-lin Apr 23, 2025
c9adb76
fix
jinzhen-lin Apr 23, 2025
855efb0
fix
jinzhen-lin Apr 23, 2025
63c23a9
fix
jinzhen-lin Apr 23, 2025
4887c4d
fix
jinzhen-lin Apr 23, 2025
21047a3
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin Apr 23, 2025
1349629
fix 'variable "xxx" was declared but never referenced' warning
jinzhen-lin Apr 23, 2025
31f65ce
rerun
jinzhen-lin Apr 23, 2025
e2d255a
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin Apr 24, 2025
5f0370a
fix
jinzhen-lin Apr 24, 2025
d872ceb
fix
jinzhen-lin Apr 24, 2025
a538f5b
rerun
jinzhen-lin Apr 24, 2025
d29a12e
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin Apr 28, 2025
3dbd866
fix 'variable "xxx" was declared but never referenced'
jinzhen-lin Apr 28, 2025
fcac83b
fix
jinzhen-lin Apr 28, 2025
f55339e
fix
jinzhen-lin Apr 28, 2025
2656b1b
update
jinzhen-lin Apr 28, 2025
1dd2f2b
update
jinzhen-lin Apr 28, 2025
0523d57
fix
jinzhen-lin Apr 28, 2025
a3345bb
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin Apr 29, 2025
2df71dc
fix
jinzhen-lin Apr 29, 2025
bdb9f10
add comment
jinzhen-lin Apr 29, 2025
ade7fcd
fix
jinzhen-lin Apr 29, 2025
a86e539
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin Apr 30, 2025
305c0bd
fix
jinzhen-lin Apr 30, 2025
aa36125
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin Apr 30, 2025
da5d871
rerun
jinzhen-lin Apr 30, 2025
e1cec3c
rerun
jinzhen-lin Apr 30, 2025
645f16f
fix
jinzhen-lin May 1, 2025
37c4c43
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin May 1, 2025
8444f78
fix
jinzhen-lin May 1, 2025
77addcd
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin May 2, 2025
dd5cce2
fix
jinzhen-lin May 4, 2025
7dd8299
Merge remote-tracking branch 'origin/main' into marlin-kernel-optimiz…
jinzhen-lin May 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,52 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)

#
# For the Marlin kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
# Path of the Python script that generates the Marlin kernel sources.
set(MARLIN_GEN_SCRIPT
  ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)

# Hash the script so generation can be skipped when it has not changed since
# the last successful run (the previous hash is kept in the CMake cache).
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)

message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")

# Both operands are quoted so that an empty cached value cannot turn the
# comparison into a syntax error and neither side is re-dereferenced as a
# variable name.
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
    OR NOT "$CACHE{MARLIN_GEN_SCRIPT_HASH}" STREQUAL "${MARLIN_GEN_SCRIPT_HASH}")
  # Notes on the invocation below:
  #  * execute_process() does not go through a shell, so the caller's
  #    PYTHONPATH has to be forwarded with $ENV{...}; a bare "$PYTHONPATH"
  #    would be passed to the script as that literal string.
  #  * Only RESULT_VARIABLE is captured into a variable. stdout and stderr
  #    both go to the log file, so the exit code checked afterwards cannot be
  #    overwritten by captured output.
  execute_process(
    COMMAND ${CMAKE_COMMAND} -E env
            "PYTHONPATH=$ENV{PYTHONPATH}"
            ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
    RESULT_VARIABLE marlin_generation_result
    OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
    ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
  )

  if (NOT marlin_generation_result EQUAL 0)
    message(FATAL_ERROR "Marlin generation failed."
                        " Result: \"${marlin_generation_result}\""
                        "\nCheck the log for details: "
                        "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
  else()
    # Record the hash only after a successful run so that a failed generation
    # is retried on the next configure. FORCE is required to overwrite the
    # value stored by a previous run.
    set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
        CACHE STRING "Last run Marlin generate script hash" FORCE)
    message(STATUS "Marlin generation completed successfully.")
  endif()
else()
  message(STATUS "Marlin generation script has not changed, skipping generation.")
endif()

# Collect every kernel_*.cu file in the gptq_marlin directory.
# NOTE(review): these are presumably the files emitted by generate_kernels.py
# in the same directory - confirm against the script.
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
# Hand the collected sources and the selected MARLIN_ARCHS to the project
# helper (defined outside this view; judging by its name it applies
# per-source gencode flags - TODO confirm).
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")

# Add the collected kernel sources to the VLLM_EXT_SRC source list.
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})

set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
Expand Down Expand Up @@ -644,7 +688,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output
Expand Down
1 change: 1 addition & 0 deletions csrc/moe/marlin_moe_wna16/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
kernel_*.cu
20 changes: 12 additions & 8 deletions csrc/moe/marlin_moe_wna16/generate_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")

# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
Expand All @@ -52,21 +50,29 @@ def remove_old_kernels():

def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = []

for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

has_act_order = group_blocks == 0
if has_zp and has_act_order:
# the act-order case is only supported for gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
]:
continue
if thread_configs[2] == 256:
# for small batches (m_blocks <= 1), we only need (128, 128, 256)
# for large batches (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue

# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue

k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
Expand All @@ -82,8 +88,6 @@ def generate_new_kernels():
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
has_act_order=has_act_order,
has_zp=has_zp,
group_blocks=group_blocks,
is_zp_float=False,
)
Expand Down
10 changes: 4 additions & 6 deletions csrc/moe/marlin_moe_wna16/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce
bool use_fp32_reduce, int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
Expand All @@ -33,11 +33,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

Expand Down
Loading