Commit 1e166cf

pyc96shuaills authored and committed
Fix MTP with Deepseek R1 Fp4 (sgl-project#7376)
1 parent 2dccc72

File tree: 3 files changed (+20, -1 lines)

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 6 additions & 0 deletions

@@ -330,6 +330,12 @@ def __init__(
         self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
         self.expert_map = None
+
+        if enable_flashinfer_moe and quant_config is None:
+            logger.warning("Disable flashinfer MoE when quantization config is None.")
+            enable_flashinfer_moe = False
+            enable_ep_moe = False
+
         self.enable_flashinfer_moe = enable_flashinfer_moe
         if enable_ep_moe:
             assert (
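
(For context: the new guard makes FusedMoE fall back to the stock Triton path when no quantization config is present, since the FlashInfer MoE kernels are evidently only wired up for quantized weights. Below is a minimal standalone sketch of that behavior; the helper name and signature are illustrative, not sglang's actual API.)

    import logging

    logger = logging.getLogger(__name__)

    def resolve_moe_flags(enable_flashinfer_moe: bool, enable_ep_moe: bool, quant_config):
        # Hypothetical helper mirroring the guard added to FusedMoE.__init__:
        # with no quantization config there are no quantized weights for the
        # FlashInfer kernels to consume, so fall back to the Triton path.
        if enable_flashinfer_moe and quant_config is None:
            logger.warning("Disable flashinfer MoE when quantization config is None.")
            enable_flashinfer_moe = False
            enable_ep_moe = False
        return enable_flashinfer_moe, enable_ep_moe

    # An unquantized model that requested flashinfer MoE falls back cleanly:
    assert resolve_moe_flags(True, True, None) == (False, False)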

python/sglang/srt/models/deepseek_nextn.py

Lines changed: 6 additions & 0 deletions

@@ -44,6 +44,12 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+            )
+            quant_config = None
+
         self.vocab_size = config.vocab_size

         self.embed_tokens = VocabParallelEmbedding(
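
(For context: with a ModelOpt FP4 checkpoint, the NextN/MTP draft module is now constructed with quant_config=None, i.e. with unquantized layers; the FP4 checkpoint evidently does not quantize the MTP head. A self-contained sketch of the override follows; the fake config class is a stand-in, though get_name() itself matches the method used in the diff.)

    import logging

    logger = logging.getLogger(__name__)

    class FakeModelOptFp4Config:
        # Stand-in for sglang's quantization config; only get_name() matters here.
        def get_name(self) -> str:
            return "modelopt_fp4"

    def nextn_quant_config(quant_config):
        # Build the MTP draft module unquantized when the base model is
        # ModelOpt FP4, mirroring the override in the NextN __init__.
        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
            logger.warning("Overriding NextN quant config for modelopt_fp4 model.")
            return None
        return quant_config

    assert nextn_quant_config(FakeModelOptFp4Config()) is None
    assert nextn_quant_config(None) is None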

python/sglang/srt/models/deepseek_v2.py

Lines changed: 8 additions & 1 deletion

@@ -2201,7 +2201,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
                     q_a_proj_weight = cached_a_proj[q_a_proj_name]
                     kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                     cat_dim = 0
-                    if (
+                    if self.quant_config is not None and (
                         self.quant_config.get_name() == "awq"
                         or self.quant_config.get_name() == "moe_wna16"
                     ):
@@ -2232,6 +2232,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
                 for scale in ["k_scale", "v_scale"]:
                     if scale in name:
                         name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                        break
+                if name not in params_dict:
+                    # modelopt ckpt contains not needed weights for MTP module:
+                    # model.decoder.self_attn.attn_mqa.v_scale and
+                    # model.decoder.self_attn.attn_mqa.k_scale
+                    logger.warning(f"{name} not found in params_dict.")
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(
                     param, "weight_loader", default_weight_loader
