Commit acede7d

Fuse routed scaling factor into select_experts for FP4
1 parent 9c138a0 commit acede7d

8 files changed: +87 −32 lines changed

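The change relies on the linearity of the MoE combine step: the layer output is a weighted sum of expert outputs, so multiplying the top-k routing weights by routed_scaling_factor during expert selection gives the same result as multiplying the combined output afterwards, while saving an extra elementwise pass over the output. A minimal standalone sketch of that equivalence (toy tensors, not the sglang code paths):

import torch

def combine(expert_outs, weights):
    # expert_outs: [num_tokens, topk, hidden], weights: [num_tokens, topk]
    return (expert_outs * weights.unsqueeze(-1)).sum(dim=1)

torch.manual_seed(0)
expert_outs = torch.randn(4, 8, 16)   # 4 tokens, top-8 experts, hidden size 16
weights = torch.softmax(torch.randn(4, 8), dim=-1)
s = 2.5  # routed_scaling_factor

out_scaled_after = s * combine(expert_outs, weights)    # old: scale the MoE output
out_fused_weights = combine(expert_outs, weights * s)   # new: factor fused into the topk weights
assert torch.allclose(out_scaled_after, out_fused_weights, atol=1e-6)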

python/sglang/srt/layers/moe/topk.py

Lines changed: 27 additions & 1 deletion
@@ -132,6 +132,7 @@ def __init__(
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
+        apply_routed_scaling_factor_on_output: Optional[bool] = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -147,6 +148,9 @@ def __init__(
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor
+        self.apply_routed_scaling_factor_on_output = (
+            apply_routed_scaling_factor_on_output
+        )

         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]

@@ -204,6 +208,7 @@ def forward_cuda(
             correction_bias=self.correction_bias,
             torch_native=torch_native,
             routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=self.apply_routed_scaling_factor_on_output,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
@@ -372,6 +377,7 @@ def grouped_topk_gpu(
     topk_group: int = 0,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
@@ -417,6 +423,8 @@ def grouped_topk_gpu(
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+        if apply_routed_scaling_factor_on_output:
+            topk_weights *= routed_scaling_factor

     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -433,10 +441,12 @@ def grouped_topk_cpu(
     topk_group: int = 0,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert expert_location_dispatch_info is None
+    assert not apply_routed_scaling_factor_on_output
     return torch.ops.sgl_kernel.grouped_topk_cpu(
         hidden_states,
         gating_output,
@@ -461,6 +471,7 @@ def biased_grouped_topk_impl(
     topk_group: int = 0,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
@@ -509,7 +520,10 @@ def biased_grouped_topk_impl(
             if num_fused_shared_experts == 0
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
-        topk_weights = topk_weights / topk_weights_sum
+        if apply_routed_scaling_factor_on_output:
+            topk_weights = topk_weights / topk_weights_sum * routed_scaling_factor
+        else:
+            topk_weights = topk_weights / topk_weights_sum

     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -549,6 +563,7 @@ def biased_grouped_topk_gpu(
     num_expert_group: int = 0,
     topk_group: int = 0,
     num_fused_shared_experts: int = 0,
+    apply_routed_scaling_factor_on_output: bool = False,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -571,6 +586,7 @@ def biased_grouped_topk_gpu(
             topk,
             num_fused_shared_experts,
             routed_scaling_factor,
+            apply_routed_scaling_factor_on_output,
         )
         # TODO merge into kernel
         if (expert_location_dispatch_info is not None) or (
@@ -581,6 +597,7 @@ def biased_grouped_topk_gpu(
         )
         return topk_weights, topk_ids
     elif _use_aiter:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        token = gating_output.shape[0]
        device = gating_output.device
        assert (
@@ -610,6 +627,7 @@ def biased_grouped_topk_gpu(
             topk_group,
             num_fused_shared_experts=num_fused_shared_experts,
             routed_scaling_factor=routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
@@ -626,10 +644,12 @@ def biased_grouped_topk_cpu(
     compiled: bool = True,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert expert_location_dispatch_info is None
+    assert not apply_routed_scaling_factor_on_output
     return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
         hidden_states,
         gating_output,
@@ -669,6 +689,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ) -> TopKOutput:
@@ -694,6 +715,7 @@ def select_experts(
             topk_group=topk_group,
             num_fused_shared_experts=num_fused_shared_experts,
             routed_scaling_factor=routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
@@ -707,6 +729,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
             num_fused_shared_experts=num_fused_shared_experts,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
@@ -716,13 +739,15 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk_native"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
@@ -737,6 +762,7 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in custom_routing_function"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
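Taken together, the topk.py changes thread the new apply_routed_scaling_factor_on_output flag from TopK.__init__ through select_experts into the grouped and biased-grouped top-k paths; when the flag is set, the renormalized weights are multiplied by routed_scaling_factor before being returned, and paths that do not implement the fusion (CPU ops, aiter, fused_topk, custom routing functions) assert that it is off. A condensed sketch of the new weight post-processing, simplified from biased_grouped_topk_impl (names follow the diff; shapes are assumed):

import torch
from typing import Optional

def finalize_topk_weights(
    topk_weights: torch.Tensor,  # [num_tokens, topk] raw gate scores
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    apply_routed_scaling_factor_on_output: bool = False,
) -> torch.Tensor:
    # Renormalize over the routed experts only; a fused shared expert
    # (last column) is excluded from the normalization sum.
    topk_weights_sum = (
        topk_weights.sum(dim=-1, keepdim=True)
        if num_fused_shared_experts == 0
        else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
    )
    if apply_routed_scaling_factor_on_output:
        # Fused path: bake routed_scaling_factor into the weights so the
        # MoE output needs no rescaling afterwards.
        return topk_weights / topk_weights_sum * routed_scaling_factor
    return topk_weights / topk_weights_sum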

python/sglang/srt/layers/quantization/modelopt_quant.py

Lines changed: 1 addition & 4 deletions
@@ -967,6 +967,7 @@ def apply(
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
     ) -> torch.Tensor:
+        # Scale by routed_scaling_factor is fused into select_experts.
         assert activation == "silu", "Only SiLU activation is supported."

         if self.enable_flashinfer_cutlass_moe:
@@ -997,8 +998,6 @@ def apply(
                 tp_rank=tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            if routed_scaling_factor is not None:
-                output *= routed_scaling_factor
             return output

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1019,6 +1018,4 @@ def apply(
             params=layer.cutlass_moe_params,
             apply_router_weight_on_input=apply_router_weight_on_input,
         ).to(x.dtype)
-        if routed_scaling_factor is not None:
-            output *= routed_scaling_factor
         return output
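With the factor already baked into the top-k weights by select_experts, the FP4 apply path must not rescale the combined output again; the two removed output *= routed_scaling_factor blocks would now double-apply the factor. A toy illustration of the hazard (standalone, made-up values):

import torch

weights = torch.tensor([[0.6, 0.4]])  # renormalized topk weights for one token
expert_outs = torch.ones(1, 2, 4)     # two expert outputs, hidden size 4
s = 2.5                               # routed_scaling_factor

fused = (expert_outs * (weights * s).unsqueeze(-1)).sum(dim=1)  # what the kernels return now
double_scaled = fused * s                                       # what the removed post-scaling would add
assert torch.allclose(fused, torch.full((1, 4), 2.5))
assert torch.allclose(double_scaled, torch.full((1, 4), 6.25))  # wrong by a factor of s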

python/sglang/srt/models/deepseek_v2.py

Lines changed: 21 additions & 15 deletions
@@ -62,6 +62,7 @@
     use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -80,6 +81,7 @@
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.utils import is_sm100_supported
@@ -306,21 +308,6 @@ def __init__(
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = (
-            TopK(
-                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-                renormalize=config.norm_topk_prob,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not use_flashinfer_trtllm_moe
-            else None
-        )
-
         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
@@ -360,6 +347,25 @@ def __init__(
             ),
         )

+        apply_routed_scaling_factor_on_output = isinstance(
+            self.experts, FusedMoE
+        ) and isinstance(self.experts.quant_method, ModelOptNvFp4FusedMoEMethod)
+        self.topk = (
+            TopK(
+                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+                renormalize=config.norm_topk_prob,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+            )
+            if not use_flashinfer_trtllm_moe
+            else None
+        )
+
         self.shared_experts_is_int8 = False
         self.shared_experts_is_fp8 = False
         self.shared_experts_weight_block_size = None

sgl-kernel/csrc/common_extension.cc

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
-      "num_fused_shared_experts, float routed_scaling_factor) -> "
+      "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
   m.def(
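With the updated schema, callers pass the flag as a trailing boolean when invoking the registered op. A hedged sketch of such a call through the torch custom-op registry (requires a CUDA device and the sgl_kernel extension; shapes and values are illustrative, borrowed from the DeepSeek-V3 configuration mentioned in the kernel comments):

import torch
import sgl_kernel  # registers the sgl_kernel custom ops (assumed importable)

num_tokens, num_experts = 8, 256
gating_output = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")
correction_bias = torch.randn(num_experts, dtype=torch.float32, device="cuda")

topk_weights, topk_ids = torch.ops.sgl_kernel.moe_fused_gate(
    gating_output,    # Tensor input
    correction_bias,  # Tensor bias
    8,                # int num_expert_group
    4,                # int topk_group
    8,                # int topk
    0,                # int num_fused_shared_experts
    2.5,              # float routed_scaling_factor
    True,             # bool apply_routed_scaling_factor_on_output (new)
)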

sgl-kernel/csrc/moe/moe_fused_gate.cu

Lines changed: 20 additions & 7 deletions
@@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl(
     int64_t topk,
     int64_t num_fused_shared_experts,
     double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output,
     Params params) {
   int tidx = threadIdx.x;
   int64_t thread_row =
@@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl(
     for (int ii = 0; ii < topk; ++ii) {
       int64_t const idx = topk * thread_row + ii;
       output_ptr[idx] = output_ptr[idx] / output_sum;
+      if (apply_routed_scaling_factor_on_output) {
+        output_ptr[idx] *= routed_scaling_factor;
+      }
     }
   }
 }
@@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor) {
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
   moe_fused_gate_impl<T>(
       input,
@@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
+      apply_routed_scaling_factor_on_output,
       params);
 }

@@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel(
       topk_group,                              \
       topk,                                    \
       num_fused_shared_experts,                \
-      routed_scaling_factor);                  \
+      routed_scaling_factor,                   \
+      apply_routed_scaling_factor_on_output);  \
   dispatched = true;                           \
 } while (0)

@@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor) {
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts;  // e.g, for deepseek v3, this is 256
   params.VPT = num_experts / num_expert_group;  // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
+      apply_routed_scaling_factor_on_output,
       params);
 }

@@ -374,7 +383,8 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor) {
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -473,7 +483,8 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor);
+        routed_scaling_factor,
+        apply_routed_scaling_factor_on_output);
   } else if (input.scalar_type() == at::kHalf) {
     moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -486,7 +497,8 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor);
+        routed_scaling_factor,
+        apply_routed_scaling_factor_on_output);
   } else if (input.scalar_type() == at::kFloat) {
     moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -499,7 +511,8 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor);
+        routed_scaling_factor,
+        apply_routed_scaling_factor_on_output);
   } else {
     TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
   }

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 2 additions & 1 deletion
@@ -244,7 +244,8 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor);
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output);

 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,
