Commit 4287a14

trevor-mssssnow authored and committed
[1/2] sgl-kernel: Fuse routed scaling factor into select_experts (#8364)
1 parent fc3789e commit 4287a14

5 files changed: +38 -12 lines changed

sgl-kernel/csrc/common_extension.cc

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
-      "num_fused_shared_experts, float routed_scaling_factor) -> "
+      "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
   m.def(

sgl-kernel/csrc/moe/moe_fused_gate.cu

Lines changed: 20 additions & 7 deletions
@@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl(
     int64_t topk,
     int64_t num_fused_shared_experts,
     double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output,
     Params params) {
   int tidx = threadIdx.x;
   int64_t thread_row =
@@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl(
     for (int ii = 0; ii < topk; ++ii) {
       int64_t const idx = topk * thread_row + ii;
       output_ptr[idx] = output_ptr[idx] / output_sum;
+      if (apply_routed_scaling_factor_on_output) {
+        output_ptr[idx] *= routed_scaling_factor;
+      }
     }
   }
 }
@@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor) {
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
   moe_fused_gate_impl<T>(
       input,
@@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
+      apply_routed_scaling_factor_on_output,
       params);
 }

@@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel(
       topk_group, \
       topk, \
       num_fused_shared_experts, \
-      routed_scaling_factor); \
+      routed_scaling_factor, \
+      apply_routed_scaling_factor_on_output); \
     dispatched = true; \
   } while (0)

@@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor) {
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts;  // e.g, for deepseek v3, this is 256
   params.VPT = num_experts / num_expert_group;  // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
+      apply_routed_scaling_factor_on_output,
       params);
 }

@@ -374,7 +383,8 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor) {
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -473,7 +483,8 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor);
+        routed_scaling_factor,
+        apply_routed_scaling_factor_on_output);
   } else if (input.scalar_type() == at::kHalf) {
     moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -486,7 +497,8 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor);
+        routed_scaling_factor,
+        apply_routed_scaling_factor_on_output);
   } else if (input.scalar_type() == at::kFloat) {
     moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -499,7 +511,8 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor);
+        routed_scaling_factor,
+        apply_routed_scaling_factor_on_output);
   } else {
     TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
   }
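
The kernel change above folds the routed scaling factor into the normalization epilogue of expert selection, so callers no longer need a separate elementwise multiply over the router output. A minimal Python sketch of the intended semantics (reference math only, not the CUDA code; the name fused_epilogue is illustrative):

import torch

def fused_epilogue(topk_weights, routed_scaling_factor, apply_routed_scaling_factor_on_output):
    # topk_weights: [num_rows, topk] gate values of the selected experts.
    # Normalize each row by its sum, mirroring output_ptr[idx] /= output_sum.
    out = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # New fused step: optionally scale in the same pass, mirroring the added
    # output_ptr[idx] *= routed_scaling_factor branch.
    if apply_routed_scaling_factor_on_output:
        out = out * routed_scaling_factor
    return out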

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 2 additions & 1 deletion
@@ -243,7 +243,8 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor);
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output);

 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,

sgl-kernel/python/sgl_kernel/moe.py

Lines changed: 9 additions & 2 deletions
@@ -44,15 +44,21 @@ def moe_fused_gate(
     topk,
     num_fused_shared_experts=0,
     routed_scaling_factor=0,
+    apply_routed_scaling_factor_on_output=False,
 ):
     # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
     # it split group of expert into num_expert_group, and use top2 expert weight sum in each group
     # as the group weight to select expert groups and then select topk experts within the selected groups
     # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
     # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
     # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
-    # num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts
-    # routed_scaling_factor: if > 0, the shared experts will be scaled by this factor
+    # num_fused_shared_experts: if > 0, the last several experts will be
+    # replaced with shared experts. the shared experts will be divided by the
+    # routed_scaling_factor - this is intended to cancel out later when routed+shared
+    # output is scaled so that shared experts are not scaled.
+    # routed_scaling_factor: if > 0, the experts will be scaled by this factor
+    # apply_routed_scaling_factor_on_output: if true, output will be
+    # scaled by the routed_scaling_factor
     return torch.ops.sgl_kernel.moe_fused_gate.default(
         input_tensor,
         bias,
@@ -61,6 +67,7 @@ def moe_fused_gate(
         topk,
         num_fused_shared_experts,
         routed_scaling_factor,
+        apply_routed_scaling_factor_on_output,
     )
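
For reference, a usage sketch of the updated Python wrapper with the new flag enabled. This is illustrative only: the shapes and routing parameters below (256 experts, 8 groups, topk 8) are assumptions rather than values from this commit, and the output variable names are mine.

import torch
from sgl_kernel.moe import moe_fused_gate

scores = torch.rand(4, 256, dtype=torch.float32, device="cuda")  # [tokens, num_experts]
bias = torch.rand(256, dtype=torch.float32, device="cuda")       # per-expert bias

topk_weights, topk_ids = moe_fused_gate(
    scores,
    bias,
    num_expert_group=8,
    topk_group=4,
    topk=8,
    num_fused_shared_experts=0,
    routed_scaling_factor=2.5,
    # New flag: scale the normalized top-k weights by routed_scaling_factor
    # inside the kernel instead of in a separate pass after selection.
    apply_routed_scaling_factor_on_output=True,
)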

sgl-kernel/tests/test_moe_fused_gate.py

Lines changed: 6 additions & 1 deletion
@@ -19,7 +19,10 @@
     ],
 )
 @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
-def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
+@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False])
+def test_moe_fused_gate_combined(
+    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
+):
     num_experts, num_expert_group, topk_group, topk = params
     dtype = torch.float32

@@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
         topk=topk,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
     ref_output, ref_indices = biased_grouped_topk(
         scores,
@@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
         topk_group=topk_group,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )

     # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
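
One property implied by the kernel change, useful as a sanity check alongside the parametrized test above: with everything else fixed, flipping apply_routed_scaling_factor_on_output should only multiply the returned weights by routed_scaling_factor and leave the selected indices untouched. A small standalone check along those lines (illustrative, not part of the test file; shapes and routing parameters are assumed):

import torch
from sgl_kernel.moe import moe_fused_gate

scores = torch.rand(16, 256, dtype=torch.float32, device="cuda")
bias = torch.rand(256, dtype=torch.float32, device="cuda")
kwargs = dict(
    num_expert_group=8,
    topk_group=4,
    topk=8,
    num_fused_shared_experts=0,
    routed_scaling_factor=2.5,
)

w_scaled, ids_scaled = moe_fused_gate(
    scores, bias, apply_routed_scaling_factor_on_output=True, **kwargs
)
w_plain, ids_plain = moe_fused_gate(
    scores, bias, apply_routed_scaling_factor_on_output=False, **kwargs
)

# The fused scaling is a pure elementwise multiply on the normalized weights,
# so the two calls should differ only by the factor of 2.5.
torch.testing.assert_close(w_scaled, w_plain * 2.5)
assert torch.equal(ids_scaled, ids_plain)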
