python/sglang/srt/layers/moe/cutlass_moe.py (1 addition & 1 deletion)
@@ -209,7 +209,7 @@ def cutlass_fused_experts_fp8(
)

result = torch.empty((m, k), device=device, dtype=out_dtype)
- apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
Contributor (severity: high):
There is a potential dtype mismatch here. The apply_shuffle_mul_sum kernel expects c2, result, and topk_weights to have the same dtype (out_dtype). However, topk_weights is passed without an explicit cast, which can lead to incorrect results or silent errors if its dtype does not match. It should be cast to out_dtype before being passed to the kernel.

Suggested change:
- apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
+ apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
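For illustration, a minimal, self-contained sketch of the mismatch described above; the shapes and the bfloat16 out_dtype are assumed for the example and are not taken from the PR. Routing weights are often produced in float32 (e.g. by a softmax over router logits), while out_dtype is usually the model's half-precision compute dtype, which is how the mismatch arises.

```python
import torch

# Assumed shapes/dtypes for illustration only (not from the PR).
out_dtype = torch.bfloat16
topk_weights = torch.softmax(torch.randn(8, 2), dim=-1)  # routing weights, float32
result = torch.empty((8, 16), dtype=out_dtype)           # MoE output buffer

assert topk_weights.dtype != result.dtype                 # the mismatch flagged above
topk_weights = topk_weights.to(out_dtype)                 # the suggested cast
assert topk_weights.dtype == result.dtype
```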

+ apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
Collaborator:
Can we also add a dtype check on the C++ side, in void get_apply_shuffle_mul_sum_caller(), to avoid any misuse in the future?

Collaborator:
It seems we want to make sure c2, result, and topk_weights have the same dtype, right?

Collaborator (Author):
We already have the check, hence the need to cast to out_dtype.
Also, the kernel is written specifically for the MoE weighted sum, so it will likely only ever be used in this file.

TORCH_CHECK(factors_opt->dtype() == output_tensor.dtype(), "Factors must match output dtype");
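For completeness, an equivalent guard expressed on the Python side could look like the sketch below; the helper name and signature are hypothetical and not part of this PR or of the kernel library.

```python
import torch

def check_shuffle_mul_sum_dtypes(
    c2: torch.Tensor, result: torch.Tensor, topk_weights: torch.Tensor
) -> None:
    # Hypothetical Python-side mirror of the C++ TORCH_CHECK quoted above;
    # the helper name is invented for illustration only.
    if not (c2.dtype == result.dtype == topk_weights.dtype):
        raise TypeError(
            "apply_shuffle_mul_sum expects c2, result and topk_weights to share "
            f"a dtype, got {c2.dtype}, {result.dtype}, {topk_weights.dtype}"
        )
```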

Collaborator:
I see. Just curious why this mismatch wasn't captured by the previous commit.

return result

