@@ -132,6 +132,7 @@ def __init__(
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
+        apply_routed_scaling_factor_on_output: Optional[bool] = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -147,6 +148,9 @@ def __init__(
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor
+        self.apply_routed_scaling_factor_on_output = (
+            apply_routed_scaling_factor_on_output
+        )
 
         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
 
@@ -204,6 +208,7 @@ def forward_cuda(
             correction_bias=self.correction_bias,
             torch_native=torch_native,
             routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=self.apply_routed_scaling_factor_on_output,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
@@ -372,6 +377,7 @@ def grouped_topk_gpu(
     topk_group: int = 0,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
@@ -417,6 +423,8 @@ def grouped_topk_gpu(
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+        if apply_routed_scaling_factor_on_output:
+            topk_weights *= routed_scaling_factor
 
     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -433,10 +441,12 @@ def grouped_topk_cpu(
     topk_group: int = 0,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert expert_location_dispatch_info is None
+    assert not apply_routed_scaling_factor_on_output
     return torch.ops.sgl_kernel.grouped_topk_cpu(
         hidden_states,
         gating_output,
@@ -461,6 +471,7 @@ def biased_grouped_topk_impl(
     topk_group: int = 0,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
@@ -509,7 +520,10 @@ def biased_grouped_topk_impl(
         if num_fused_shared_experts == 0
         else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
     )
-    topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights = topk_weights / topk_weights_sum * routed_scaling_factor
+    else:
+        topk_weights = topk_weights / topk_weights_sum
 
     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -549,6 +563,7 @@ def biased_grouped_topk_gpu(
     num_expert_group: int = 0,
     topk_group: int = 0,
     num_fused_shared_experts: int = 0,
+    apply_routed_scaling_factor_on_output: bool = False,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -571,6 +586,7 @@ def biased_grouped_topk_gpu(
             topk,
             num_fused_shared_experts,
             routed_scaling_factor,
+            apply_routed_scaling_factor_on_output,
         )
         # TODO merge into kernel
         if (expert_location_dispatch_info is not None) or (
@@ -581,6 +597,7 @@ def biased_grouped_topk_gpu(
             )
         return topk_weights, topk_ids
     elif _use_aiter:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         token = gating_output.shape[0]
         device = gating_output.device
         assert (
@@ -610,6 +627,7 @@ def biased_grouped_topk_gpu(
             topk_group,
             num_fused_shared_experts=num_fused_shared_experts,
             routed_scaling_factor=routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
@@ -626,10 +644,12 @@ def biased_grouped_topk_cpu(
     compiled: bool = True,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert expert_location_dispatch_info is None
+    assert not apply_routed_scaling_factor_on_output
     return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
         hidden_states,
         gating_output,
@@ -669,6 +689,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ) -> TopKOutput:
@@ -694,6 +715,7 @@ def select_experts(
             topk_group=topk_group,
             num_fused_shared_experts=num_fused_shared_experts,
             routed_scaling_factor=routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
@@ -707,6 +729,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
             num_fused_shared_experts=num_fused_shared_experts,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
@@ -716,13 +739,15 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk_native"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
@@ -737,6 +762,7 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in custom_routing_function"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
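For reference, here is a minimal standalone sketch (not part of the diff) of the semantics the new flag adds in the top-k paths above: when apply_routed_scaling_factor_on_output is set, the routed scaling factor is folded directly into the renormalized top-k weights inside the routing code, instead of being applied to the combined MoE output later. The helper name and example tensor below are made up for illustration only.

import torch

# Illustrative only -- this helper does not exist in sglang; it condenses the
# branch added to grouped_topk_gpu / biased_grouped_topk_impl in the diff above.
def renormalize_topk_weights(
    topk_weights: torch.Tensor,
    routed_scaling_factor: float,
    apply_routed_scaling_factor_on_output: bool = False,
) -> torch.Tensor:
    # Standard renormalization: selected-expert weights sum to 1 per token.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # New behavior: fold the routed scaling factor into the weights here,
    # so downstream code no longer has to scale the combined MoE output.
    if apply_routed_scaling_factor_on_output:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights

w = torch.tensor([[0.6, 0.2]])
print(renormalize_topk_weights(w, 2.5))        # tensor([[0.7500, 0.2500]])
print(renormalize_topk_weights(w, 2.5, True))  # tensor([[1.8750, 0.6250]])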