
Commit 72ed08f

Revert "[1/2] sgl-kernel: Fuse routed scaling factor into select_experts (#8364)"
This reverts commit f642524.
1 parent: ac6962c · commit: 72ed08f

File tree

5 files changed: +12 −38 lines changed

5 files changed

+12
-38
lines changed

sgl-kernel/csrc/common_extension.cc

Lines changed: 1 addition & 1 deletion

@@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
 
   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
-      "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
+      "num_fused_shared_experts, float routed_scaling_factor) -> "
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
   m.def(

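For reference, after this revert the registered schema is back to seven arguments, matching the Python wrapper in sgl-kernel/python/sgl_kernel/moe.py below. A minimal sketch of invoking the op through torch.ops with the reverted signature; the shapes and parameter values here are illustrative, not taken from this commit, and it assumes sgl_kernel is installed with a CUDA device available:

import torch
import sgl_kernel  # loading the package registers the sgl_kernel torch ops

# Illustrative sizes: 4 tokens, 8 experts split into 2 groups. Per the
# wrapper's docstring, #experts must be a power of 2, divisible by
# num_expert_group, with #experts / num_expert_group <= 32.
input_tensor = torch.randn(4, 8, dtype=torch.float32, device="cuda")
bias = torch.zeros(8, dtype=torch.float32, device="cuda")

# Positional args mirror the schema above: input, bias, num_expert_group,
# topk_group, topk, num_fused_shared_experts, routed_scaling_factor.
outputs = torch.ops.sgl_kernel.moe_fused_gate.default(
    input_tensor, bias, 2, 1, 2, 0, 2.5
)
# Per the "(Tensor[])" return type, `outputs` is a list of tensors.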
sgl-kernel/csrc/moe/moe_fused_gate.cu

Lines changed: 7 additions & 20 deletions

@@ -59,7 +59,6 @@ __device__ void moe_fused_gate_impl(
     int64_t topk,
     int64_t num_fused_shared_experts,
     double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output,
     Params params) {
   int tidx = threadIdx.x;
   int64_t thread_row =
@@ -249,9 +248,6 @@ __device__ void moe_fused_gate_impl(
     for (int ii = 0; ii < topk; ++ii) {
       int64_t const idx = topk * thread_row + ii;
       output_ptr[idx] = output_ptr[idx] / output_sum;
-      if (apply_routed_scaling_factor_on_output) {
-        output_ptr[idx] *= routed_scaling_factor;
-      }
     }
   }
 }
@@ -286,8 +282,7 @@ __global__ void moe_fused_gate_kernel(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
   moe_fused_gate_impl<T>(
       input,
@@ -299,7 +294,6 @@ __global__ void moe_fused_gate_kernel(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
-      apply_routed_scaling_factor_on_output,
       params);
 }
 
@@ -320,8 +314,7 @@ __global__ void moe_fused_gate_kernel(
       topk_group,                              \
       topk,                                    \
       num_fused_shared_experts,                \
-      routed_scaling_factor,                   \
-      apply_routed_scaling_factor_on_output);  \
+      routed_scaling_factor);                  \
   dispatched = true;                           \
 } while (0)
 
@@ -349,8 +342,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts;             // e.g, for deepseek v3, this is 256
   params.VPT = num_experts / num_expert_group;  // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -369,7 +361,6 @@ __global__ void moe_fused_gate_kernel_dynamic(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
-      apply_routed_scaling_factor_on_output,
       params);
 }
 
@@ -383,8 +374,7 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -483,8 +473,7 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor,
-        apply_routed_scaling_factor_on_output);
+        routed_scaling_factor);
   } else if (input.scalar_type() == at::kHalf) {
     moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -497,8 +486,7 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor,
-        apply_routed_scaling_factor_on_output);
+        routed_scaling_factor);
   } else if (input.scalar_type() == at::kFloat) {
     moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -511,8 +499,7 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor,
-        apply_routed_scaling_factor_on_output);
+        routed_scaling_factor);
   } else {
     TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
   }

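The second hunk above restores the kernel epilogue to plain sum-normalization: each row's top-k weights are divided by their row sum, and the routed scaling factor is no longer folded into the output. A minimal PyTorch sketch of that epilogue's semantics (an illustration of the per-row loop, not the CUDA code itself):

import torch

def normalize_topk_weights(topk_weights: torch.Tensor) -> torch.Tensor:
    # Equivalent of the loop in moe_fused_gate_impl after the revert:
    # output_ptr[idx] = output_ptr[idx] / output_sum, with no trailing
    # multiplication by routed_scaling_factor.
    return topk_weights / topk_weights.sum(dim=-1, keepdim=True)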
sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 1 addition & 2 deletions

@@ -243,8 +243,7 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output);
+    double routed_scaling_factor);
 
 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,

sgl-kernel/python/sgl_kernel/moe.py

Lines changed: 2 additions & 9 deletions

@@ -44,21 +44,15 @@ def moe_fused_gate(
     topk,
     num_fused_shared_experts=0,
     routed_scaling_factor=0,
-    apply_routed_scaling_factor_on_output=False,
 ):
     # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
     # it split group of expert into num_expert_group, and use top2 expert weight sum in each group
     # as the group weight to select expert groups and then select topk experts within the selected groups
     # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
     # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
     # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
-    # num_fused_shared_experts: if > 0, the last several experts will be
-    # replaced with shared experts. the shared experts will be divided by the
-    # routed_scaling_factor - this is intended to cancel out later when routed+shared
-    # output is scaled so that shared experts are not scaled.
-    # routed_scaling_factor: if > 0, the experts will be scaled by this factor
-    # apply_routed_scaling_factor_on_output: if true, output will be
-    # scaled by the routed_scaling_factor
+    # num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts
+    # routed_scaling_factor: if > 0, the shared experts will be scaled by this factor
     return torch.ops.sgl_kernel.moe_fused_gate.default(
         input_tensor,
         bias,
@@ -67,7 +61,6 @@ def moe_fused_gate(
         topk,
         num_fused_shared_experts,
         routed_scaling_factor,
-        apply_routed_scaling_factor_on_output,
     )
 
 

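The docstring above describes a hierarchical two-layer selection: group the experts, score each group by the sum of its top-2 expert weights, keep the best groups, then pick the final top-k experts inside them. A rough PyTorch sketch of that scheme for illustration only; the real kernel and the biased_grouped_topk reference differ in details (e.g. exactly where the bias enters the scoring and how fused shared experts are appended), so every name and step here is an assumption:

import torch

def grouped_topk_sketch(scores, bias, num_expert_group, topk_group, topk):
    # scores: [num_tokens, num_experts]; bias: [num_experts].
    num_tokens, num_experts = scores.shape
    gating = scores + bias  # bias-corrected scores used for selection
    grouped = gating.view(num_tokens, num_expert_group, -1)
    # Group weight = sum of the top-2 expert scores within each group.
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    # Keep only the topk_group best groups; mask out the rest.
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(num_tokens, num_experts)
    masked = gating.masked_fill(expert_mask == 0, float("-inf"))
    # Final top-k experts are chosen within the selected groups.
    topk_ids = masked.topk(topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    # Normalize per row, matching the kernel epilogue in moe_fused_gate.cu.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids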
sgl-kernel/tests/test_moe_fused_gate.py

Lines changed: 1 addition & 6 deletions

@@ -19,10 +19,7 @@
     ],
 )
 @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
-@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False])
-def test_moe_fused_gate_combined(
-    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
-):
+def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
     num_experts, num_expert_group, topk_group, topk = params
     dtype = torch.float32
 
@@ -40,7 +37,6 @@ def test_moe_fused_gate_combined(
         topk=topk,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
-        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
     ref_output, ref_indices = biased_grouped_topk(
         scores,
@@ -52,7 +48,6 @@ def test_moe_fused_gate_combined(
         topk_group=topk_group,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
-        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
 
     # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension

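The test's closing comment says the last top-k column is excluded from the comparison when shared experts are fused in. A hypothetical helper showing one way that exclusion could look; the helper name, the tolerances, and the single-column slice are assumptions, not code from this test:

import torch

def compare_ignoring_shared_slot(output, ref_output, num_fused_shared_experts):
    # Hypothetical helper: drop the last top-k column before comparing,
    # since with num_fused_shared_experts > 0 that slot holds a shared
    # expert rather than a routed one.
    if num_fused_shared_experts > 0:
        output, ref_output = output[:, :-1], ref_output[:, :-1]
    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=1e-3)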