Conversation

@Alcanderian (Collaborator) commented Jul 2, 2025

Motivation

ATTENTION: flashinfer cutlass MoE may not work after installing the trtllm wheel.

UPD 072025: achieves up to 134 tps with tp8

UPD 070525: achieves up to 128 tps

SGLANG_TRTLLM_GEN_MOE_EP_SIZE=2 SGLANG_ENABLE_TRTLLM_GEN_MOE=1 python3 -m sglang.launch_server \
--model-path /dev/shm/DeepSeek-R1-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 8 \
--enable-flashinfer-moe --enable-ep-moe --enable-flashinfer-allreduce-fusion

ENV

pip3 install --no-cache-dir tensorrt-llm==1.0.0rc0 --no-deps
pip3 install --no-cache-dir tensorrt~=10.11.0 \
  tensorrt_cu12_bindings~=10.11.0 tensorrt_cu12~=10.11.0 \
  tensorrt_cu12_libs~=10.11.0 --no-deps
pip3 install --no-cache-dir nvtx mpi4py onnx "onnx_graphsurgeon>=0.5.2" \
  StrEnum "accelerate>=0.25.0" "nvidia-modelopt[torch]~=0.31.0"
pip3 install --no-cache-dir aenum backoff click_option_group \
  colored "diffusers>=0.27.0" etcd3 evaluate meson openai opencv-python-headless
pip3 install --no-cache-dir optimum ordered-set peft polygraphy \
  "tornado>=6.5.0" "setuptools<80" h5py==3.12.1 blake3

pip3 install -e "python[blackwell]"
pip3 install --no-cache-dir nvidia-nccl-cu12==2.27.6 --force-reinstall --no-deps
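
Since the note above warns that the trtllm wheel can break the flashinfer cutlass MoE path, a quick check of the pinned versions before launching can help. The snippet below is only a convenience sketch; the distribution names are taken from the pip commands above.

# Sanity-check the pinned packages from the commands above (sketch only).
from importlib.metadata import PackageNotFoundError, version

for pkg in ("tensorrt-llm", "tensorrt", "tensorrt_cu12", "nvidia-nccl-cu12", "nvidia-modelopt"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")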

RUN

SGLANG_ENABLE_TRTLLM_GEN_MOE=1 python3 -m sglang.launch_server --model-path /dev/shm/DeepSeek-R1-FP4 \
    --trust-remote-code --quantization modelopt_fp4 --tp 8 --enable-flashinfer-moe

# ACC
python3 benchmark/gsm8k/bench_sglang.py --parallel 1400 --num-questions 1400
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:40<00:00, 32.91it/s]
Accuracy: 0.951
Invalid: 0.000
Latency: 40.467 s
Output throughput: 3108.238 token/s

lm_eval
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9659|±  | 0.005|
|     |       |strict-match    |     5|exact_match|↑  |0.9492|±  | 0.006|

# TPS
[2025-07-02 11:51:13 TP0] Decode batch. #running-req: 1, #token: 1236, token usage: 0.00, cuda graph: True, gen throughput (token/s): 119.54, #queue-req:

TODO: add a custom cublas_mm kernel so that the router gemm always gives fp32 output, https://github.com/sgl-project/sglang/pull/7711/files#diff-5b9e34dd492bd8a14702a18b594721091092276fad1cf8736fba6ef1f33c1b04R247

Modifications

Checklist

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @Alcanderian, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily focuses on enabling FP4 quantization for Mixture of Experts (MoE) layers using TensorRT-LLM, aiming to optimize performance for models like DeepSeekV3. It introduces a new, optimized execution flow for MoE operations and includes minor adjustments to existing dependencies and MoE processing logic.

Highlights

  • TensorRT-LLM FP4 MoE Integration: I've integrated a new execution path for Mixture of Experts (MoE) layers that leverages TensorRT-LLM's FP4 quantization and block-scale MoE runner. This path is conditionally enabled and uses torch.ops.trtllm.fp4_quantize and torch.ops.trtllm.fp4_block_scale_moe_runner for optimized processing, with hardcoded parameters for DeepSeekV3.
  • FlashInfer Dependency Adjustment: I've updated the required flashinfer_python version in pyproject.toml and the corresponding version assertion in engine.py from 0.2.7 to 0.2.6.post1.
  • DeepSeekV2 MoE Optimization: For DeepSeekV2 models, I've increased the DUAL_STREAM_TOKEN_THRESHOLD from 1024 to 2048 and reordered the router_logits calculation within the forward_normal_dual_stream method to occur inside the alt_stream context, aiming to improve dual-stream processing efficiency (a sketch follows this list).
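
A rough sketch of the dual-stream ordering described in the last bullet (method and attribute names are illustrative, not the exact sglang code): the gate and routed experts run on the alternate stream, overlapping with the shared experts on the default stream.

import torch

# Raised from 1024 to 2048 in this PR.
DUAL_STREAM_TOKEN_THRESHOLD = 2048

def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
    current_stream = torch.cuda.current_stream()
    self.alt_stream.wait_stream(current_stream)
    with torch.cuda.stream(self.alt_stream):
        # router_logits is now computed inside the alt_stream context,
        # together with the routed-expert path.
        router_logits = self.gate(hidden_states)
        routed_output = self.experts(hidden_states, router_logits)
    # Shared experts overlap on the default stream.
    shared_output = self.shared_experts(hidden_states)
    current_stream.wait_stream(self.alt_stream)
    return shared_output + routed_output
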
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

|Feature|Command|Description|
|---|---|---|
|Code Review|/gemini review|Performs a code review for the current pull request in its current state.|
|Pull Request Summary|/gemini summary|Provides a summary of the current pull request in its current state.|
|Comment|@gemini-code-assist|Responds in comments when explicitly tagged, both in issue comments and review comments.|
|Help|/gemini help|Displays a list of available commands.|

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for TensorRT-LLM FP4, which is a great optimization. The changes are mostly on the right track, but I've found a critical logic error in modelopt_quant.py that makes some code unreachable and could lead to incorrect behavior. Additionally, the new TRT-LLM implementation contains several hardcoded, model-specific values that should be generalized. Addressing these issues will make the new feature more robust and maintainable.

@Alcanderian (Collaborator, Author)

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request enables FP4 quantization for MoE layers using TensorRT-LLM, which is a significant performance optimization. The changes involve adding new code paths controlled by the SGLANG_ENABLE_TRTLLM_GEN_MOE environment variable, including new weight processing logic and a custom kernel runner for MoE. A comprehensive test suite from TensorRT-LLM has also been added to validate the new functionality.

My review focuses on improving the maintainability and robustness of the new code. I've pointed out several hardcoded values that are tied to specific model configurations or kernel internals, which could make the code brittle. I've suggested using named constants or deriving these values from configuration where possible. I also noted some areas where the code complexity has increased and could benefit from refactoring to improve clarity. Overall, this is a great step towards higher performance, and addressing these points will make the implementation more solid.

Comment on lines +898 to +908
if not ENABLE_TRTLMM_GEN_MOE:
    w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
    layer.w13_weight_scale = Parameter(
        w13_blockscale_swizzled, requires_grad=False
    )
else:
    layer.w13_weight_scale = Parameter(
        layer.w13_weight_scale.data, requires_grad=False
    )
    self.trtllm_gen_process_expert_w3_w1_weight(layer)
    self.trtllm_gen_process_expert_w3_w1_weight_scale_nvfp4(layer)

Contributor

Severity: medium

The logic within process_weights_after_loading has become quite complex with the addition of the ENABLE_TRTLMM_GEN_MOE flag. Specifically, the reassignment of layer.w13_weight_scale (and w2_weight_scale) is confusing. Initially, it's a ModelWeightParameter, but in this if not ENABLE_TRTLMM_GEN_MOE branch, it's replaced with a Parameter containing swizzled scales. This can make the code harder to understand and maintain.

Consider refactoring the logic for the two paths into separate helper methods to improve clarity. For example, _process_weights_flashinfer() and _process_weights_trtllm(). Also, instead of reassigning layer.w13_weight_scale, it might be clearer to use a different attribute name for the swizzled scales, like layer.w13_blockscale_swizzled (as it was before), to avoid confusion about its type and content.
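
A rough sketch of that split, reusing the names from the diff and the suggestion above (only the w13 handling is shown; w2 would follow the same pattern):

from torch.nn.parameter import Parameter

def process_weights_after_loading(self, layer):
    if ENABLE_TRTLMM_GEN_MOE:
        self._process_weights_trtllm(layer)
    else:
        self._process_weights_flashinfer(layer)

def _process_weights_flashinfer(self, layer):
    # Keep the swizzled scales under a dedicated attribute instead of
    # overwriting w13_weight_scale, so its type and content stay unambiguous.
    layer.w13_blockscale_swizzled = Parameter(
        self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
    )

def _process_weights_trtllm(self, layer):
    layer.w13_weight_scale = Parameter(
        layer.w13_weight_scale.data, requires_grad=False
    )
    self.trtllm_gen_process_expert_w3_w1_weight(layer)
    self.trtllm_gen_process_expert_w3_w1_weight_scale_nvfp4(layer)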

@trevor-m (Collaborator) commented Jul 3, 2025

Nice work! FYI, the TRTLLM Gen MoE kernel is being added to flashinfer, so that will help ease the dependency setup.

self.quant_config = quant_config

if ENABLE_TRTLMM_GEN_MOE:
    self.kernel = torch.ops.trtllm.nvfp4_gemm

Collaborator

The trtllm nvfp4_gemm needs profiling to find the best config. How is that handled in sglang?

Collaborator

There is probably none, from what I can see. By default it will use some heuristics, but they are not the best. There is an autotuner in trtllm that handles this: https://github.com/NVIDIA/TensorRT-LLM/pull/5207/files. I can bring that autotuner in with the flashinfer integration: flashinfer-ai/flashinfer#1214

Collaborator Author

We have no autotuner in sglang for now.

Collaborator

Then the performance may not be optimal without profiling.


# NOTE: For some unknown reason, router_gemm seems to degrade accept length.
if ENABLE_TRTLMM_GEN_MOE and not self.is_nextn:
    return torch.ops.trtllm.dsv3_router_gemm_op(

Collaborator

How is this one different from the dsv3_router_gemm on line 255?

Collaborator Author

The dsv3_router_gemm below does not have a cublas fallback, and F.linear does not support fp16 input with fp32 output for now.
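
For illustration only (not the kernel from this PR): with plain PyTorch, fp32 router logits require upcasting before F.linear, which is exactly the extra cast a cublas-backed router gemm with native fp16-in/fp32-out would avoid.

import torch
import torch.nn.functional as F

def router_logits_fp32_fallback(hidden_states: torch.Tensor,
                                router_weight: torch.Tensor) -> torch.Tensor:
    # hidden_states: [num_tokens, hidden_size], fp16/bf16
    # router_weight: [num_experts, hidden_size], fp16/bf16
    # F.linear cannot emit fp32 from fp16 inputs directly, so upcast first.
    return F.linear(hidden_states.float(), router_weight.float())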

Comment on lines 444 to 445
if ENABLE_TRTLMM_GEN_MOE:
    router_logits = self.gate(hidden_states)

Collaborator

Why this change?

Collaborator Author

This change is to reproduce the timeline from the trtllm profile.

tile_tokens_dim = 8

# https://github.com/NVIDIA/TensorRT-LLM/blob/v1.0.0rc1/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py#L195
outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(

Collaborator

I don't know if you have benchmarked the prefill perf separately. trtllm gen MoE is optimized for decoding, i.e., small input num_tokens.

Collaborator Author

I have not benchmarked prefill, but I can see that it is also called in the trtllm prefill stage. By the way, is there any fp4_gemm interface that can use the same weight/scaling-factor layout as trtllm gen MoE? We cannot change the layout of the weights/scaling factors in the forward stage.

Collaborator

Yes, I believe trtllm has added it: https://github.com/NVIDIA/TensorRT-LLM/blob/d4d21a106e8176bf20e627ee432cca5ef920c325/tests/unittest/_torch/thop/test_fp4_gemm_quantize.py#L122
But the weight layout is different from MoE.
Are you planning to run MoE as individual gemms in the forward stage? If so, this might be close to what you want:
https://github.com/NVIDIA/TensorRT-LLM/blob/d4d21a106e8176bf20e627ee432cca5ef920c325/tensorrt_llm/_torch/models/modeling_llama_min_latency.py#L47
The precision is not nvfp4.

Collaborator Author

Unfortunately, neither of these seems to help in our case.

@hlu1 (Collaborator) commented Jul 17, 2025

What I wanted to say is that the layout changes will not be compatible with single gemms, because some of the layout shuffling is for swiglu fusion. The only possibility is the nvfp4 version of GatedMLP kernels, which shares the same weight layout as MoE. These kernels can be built. They are just not checked into trtllm.

Collaborator

Actually, now that I think about it, the GatedMLP kernels would be optimized for low latency too.

@azhurkevich (Collaborator)

based tbh

@azhurkevich (Collaborator)

FYI, we merged flashinfer-ai/flashinfer#1214, and the plan is to enable these kernels through flashinfer in SGL.
