Skip to content

Commit d4325ea

Browse files
committed
Fix cutlass_fused_experts_fp8
1 parent 44a2e5b commit d4325ea

File tree

2 files changed

+18
-7
lines changed
  • python/sglang/srt/layers

2 files changed

+18
-7
lines changed

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
QuantizationConfig,
2222
QuantizeMethodBase,
2323
)
24+
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
2425
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
2526
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
2627
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -690,4 +691,11 @@ def make_expert_input_scale_params_mapping(
690691
]
691692

692693
def should_fuse_routed_scaling_factor_in_topk(self):
    """Decide whether the routed scaling factor should be fused into top-k.

    Returns True for the ModelOpt NvFp4 fused-MoE quant method, or for the
    FP8 MoE quant method when it will dispatch to the cutlass fused-experts
    FP8 kernel; False otherwise.
    """
    quant = self.quant_method
    if isinstance(quant, ModelOptNvFp4FusedMoEMethod):
        return True
    # FP8 path fuses the scaling factor only when the cutlass kernel is used.
    if isinstance(quant, Fp8MoEMethod):
        return quant.should_use_cutlass_fused_experts_fp8()
    return False

python/sglang/srt/layers/quantization/fp8.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,14 @@ def process_weights_hip_scale_padding(self, layer: Module):
969969
)
970970
torch.cuda.empty_cache()
971971

972+
def should_use_cutlass_fused_experts_fp8(self):
    """Report whether the cutlass fused-experts FP8 kernel should be used.

    All of the following must hold: the SGLANG_CUTLASS_MOE env var is set,
    cutlass FP8 is supported on this build, block quantization is enabled,
    and the GPU is SM100-capable. Checks are ordered cheapest-first and
    short-circuit, matching the original `and` chain.
    """
    if not get_bool_env_var("SGLANG_CUTLASS_MOE"):
        return False
    if not self.cutlass_fp8_supported:
        return False
    if not self.block_quant:
        return False
    return is_sm100_supported()
979+
972980
def apply(
973981
self,
974982
layer: torch.nn.Module,
@@ -1019,12 +1027,7 @@ def apply(
10191027
if ret is not None:
10201028
return ret
10211029

1022-
if (
1023-
get_bool_env_var("SGLANG_CUTLASS_MOE")
1024-
and self.cutlass_fp8_supported
1025-
and self.block_quant
1026-
and is_sm100_supported()
1027-
):
1030+
if self.should_use_cutlass_fused_experts_fp8():
10281031
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
10291032

10301033
topk_weights, topk_ids, _ = topk_output

0 commit comments

Comments (0)