2 changes: 1 addition & 1 deletion sgl-kernel/csrc/common_extension.cc
@@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"num_fused_shared_experts, float routed_scaling_factor) -> "
"num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
m.def(
27 changes: 20 additions & 7 deletions sgl-kernel/csrc/moe/moe_fused_gate.cu
@@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl(
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output,
Params params) {
int tidx = threadIdx.x;
int64_t thread_row =
@@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl(
for (int ii = 0; ii < topk; ++ii) {
int64_t const idx = topk * thread_row + ii;
output_ptr[idx] = output_ptr[idx] / output_sum;
if (apply_routed_scaling_factor_on_output) {
output_ptr[idx] *= routed_scaling_factor;
}
}
}
}
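For reference, the normalization and optional scaling above amount to the following post-processing of the selected top-k weights. This is a minimal PyTorch sketch of the same math under assumed names (`weights` is illustrative), not the kernel code itself:

```python
import torch

def normalize_topk_weights(
    weights: torch.Tensor,  # [num_rows, topk] raw weights of the selected experts
    routed_scaling_factor: float,
    apply_routed_scaling_factor_on_output: bool,
) -> torch.Tensor:
    # Normalize each row so the selected weights sum to 1.
    weights = weights / weights.sum(dim=-1, keepdim=True)
    # New in this change: optionally fold the routed scaling factor into the
    # kernel output instead of leaving that multiplication to the caller.
    if apply_routed_scaling_factor_on_output:
        weights = weights * routed_scaling_factor
    return weights
```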
@@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel(
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor) {
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output) {
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
moe_fused_gate_impl<T>(
input,
@@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel(
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
params);
}

@@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel(
topk_group, \
topk, \
num_fused_shared_experts, \
routed_scaling_factor); \
routed_scaling_factor, \
apply_routed_scaling_factor_on_output); \
dispatched = true; \
} while (0)

@@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic(
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor) {
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output) {
KernelParamsDynamic params;
params.NUM_EXPERTS = num_experts; // e.g., for deepseek v3, this is 256
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
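As a quick sanity check of the arithmetic in the comments above, here is the DeepSeek-V3-style configuration they mention, written out in Python (purely illustrative; the 32-per-group limit is the one noted in the Python wrapper below):

```python
num_experts = 256                       # e.g., DeepSeek V3
num_expert_group = 8
vpt = num_experts // num_expert_group   # 256 / 8 = 32 experts per group
assert vpt <= 32, "num_experts / num_expert_group <= 32 is required for now"
```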
@@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
params);
}

@@ -374,7 +383,8 @@ std::vector<at::Tensor> moe_fused_gate(
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor) {
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output) {
int64_t num_rows = input.size(0);
int32_t num_experts = input.size(1);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -473,7 +483,8 @@ std::vector<at::Tensor> moe_fused_gate(
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor);
routed_scaling_factor,
apply_routed_scaling_factor_on_output);
} else if (input.scalar_type() == at::kHalf) {
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
@@ -486,7 +497,8 @@ std::vector<at::Tensor> moe_fused_gate(
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor);
routed_scaling_factor,
apply_routed_scaling_factor_on_output);
} else if (input.scalar_type() == at::kFloat) {
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
@@ -499,7 +511,8 @@ std::vector<at::Tensor> moe_fused_gate(
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor);
routed_scaling_factor,
apply_routed_scaling_factor_on_output);
} else {
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
}
3 changes: 2 additions & 1 deletion sgl-kernel/include/sgl_kernel_ops.h
@@ -243,7 +243,8 @@ std::vector<at::Tensor> moe_fused_gate(
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor);
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output);

void fp8_blockwise_scaled_grouped_mm(
torch::Tensor& output,
11 changes: 9 additions & 2 deletions sgl-kernel/python/sgl_kernel/moe.py
@@ -44,15 +44,21 @@ def moe_fused_gate(
topk,
num_fused_shared_experts=0,
routed_scaling_factor=0,
apply_routed_scaling_factor_on_output=False,
):
# This fused kernel selects the top-k experts in a hierarchical 2-layer fashion:
# it splits the experts into num_expert_group groups, uses the sum of the top-2 expert weights
# in each group as the group weight to pick the top expert groups, and then selects the topk
# experts within the selected groups.
# The number of experts is inferred from the input tensor shape. Currently only a power-of-2
# number of experts is supported, it must be divisible by num_expert_group, and
# num_experts / num_expert_group <= 32 is required for now.
# For unsupported cases, use the biased_grouped_topk func in sglang.srt.layers.moe.topk instead.
# num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts
# routed_scaling_factor: if > 0, the shared experts will be scaled by this factor
# num_fused_shared_experts: if > 0, the last few experts are treated as shared experts.
# Their weights are divided by routed_scaling_factor; this is intended to cancel out later,
# when the combined routed+shared output is scaled, so that the shared experts are
# effectively not scaled.
# routed_scaling_factor: if > 0, the expert weights are scaled by this factor
# apply_routed_scaling_factor_on_output: if True, the output weights are scaled by
# routed_scaling_factor inside the kernel
return torch.ops.sgl_kernel.moe_fused_gate.default(
input_tensor,
bias,
@@ -61,6 +67,7 @@ def moe_fused_gate(
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
)
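A hypothetical caller-side sketch of the new flag follows. The import path, tensor shapes, and the (weights, indices) unpacking are assumptions based on the wrapper and the test below, not taken verbatim from the repo:

```python
import torch
from sgl_kernel import moe_fused_gate  # assumed public export

scores = torch.rand(8, 256, dtype=torch.float32, device="cuda")  # 256 experts (power of 2)
bias = torch.rand(256, dtype=torch.float32, device="cuda")

weights, indices = moe_fused_gate(
    scores,
    bias,
    num_expert_group=8,
    topk_group=4,
    topk=8,
    num_fused_shared_experts=0,
    routed_scaling_factor=2.5,
    apply_routed_scaling_factor_on_output=True,
)
# With the flag set, `weights` already include routed_scaling_factor, so the
# caller should not apply it again when combining the expert outputs.
```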


7 changes: 6 additions & 1 deletion sgl-kernel/tests/test_moe_fused_gate.py
@@ -19,7 +19,10 @@
],
)
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False])
def test_moe_fused_gate_combined(
seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
):
num_experts, num_expert_group, topk_group, topk = params
dtype = torch.float32

@@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
topk=topk,
num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=2.5,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)
ref_output, ref_indices = biased_grouped_topk(
scores,
@@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
topk_group=topk_group,
num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=2.5,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)

# When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension