38 changes: 25 additions & 13 deletions python/sglang/srt/layers/quantization/fp8.py
@@ -71,7 +71,8 @@ def prepare_fp8_layer_for_marlin(*args, **kwargs):
_is_hip = is_hip()

if _is_hip:
from aiter.fused_moe_bf16_asm import asm_moe
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
from aiter.ops.shuffle import shuffle_weight

_is_cuda = is_cuda()
@@ -487,7 +488,7 @@ def create_weights(

if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = (
torch.int32
torch.uint32
Member:

QQ: why change this? It's also used by NVIDIA GPUs.

@HaiShaw (Collaborator, Author), Apr 10, 2025:

This is the uint32-packed INT4 case for serialized checkpoints (within the flag); the else branch is the generic case of OCP/NV FP8 data read from the checkpoint.

if get_bool_env_var("USE_INT4_WEIGHT")
else torch.float8_e4m3fn
)
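
As context for the dtype switch above: with USE_INT4_WEIGHT set, the serialized checkpoint stores eight 4-bit weights per 32-bit word, so the parameter container must be an unsigned 32-bit integer tensor, while the else branch keeps torch.float8_e4m3fn for ordinary OCP/NV FP8 checkpoints. The sketch below is a hypothetical packing helper, not code from this PR or from aiter; it only illustrates the layout assumption and needs a PyTorch build with torch.uint32 support (which this code path already assumes).

import torch

def pack_int4_to_uint32(q: torch.Tensor) -> torch.Tensor:
    # Hypothetical illustration only: pack eight unsigned 4-bit values into one
    # uint32 word along the last dimension (which must be a multiple of 8).
    # Real INT4-FP8 checkpoints define their own packing layout.
    assert q.shape[-1] % 8 == 0, "last dim must hold a whole number of uint32 words"
    nibbles = q.to(torch.int64).reshape(*q.shape[:-1], -1, 8)
    shifts = torch.arange(0, 32, 4, dtype=torch.int64)   # bit offsets 0, 4, ..., 28
    words = (nibbles << shifts).sum(dim=-1)              # combine 8 nibbles per word
    return words.to(torch.uint32)

# Example: 16 INT4 values per row pack into 2 uint32 words per row.
w_int4 = torch.randint(0, 16, (4, 16))
print(pack_int4_to_uint32(w_int4).shape)  # torch.Size([4, 2])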
@@ -822,12 +823,14 @@ def process_weights_hip_int4(self, layer: Module):
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
# Weight Permutation
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
# permute_weight(layer.w13_weight.data),
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
permute_weight(layer.w2_weight.data),
# permute_weight(layer.w2_weight.data),
shuffle_weight(layer.w2_weight.data, (16, 16)),
requires_grad=False,
)
torch.cuda.empty_cache()
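
The permute_weight call is replaced with aiter's shuffle_weight using a (16, 16) layout, and the same re-wrap pattern is applied to both w13_weight and w2_weight in this hunk and the next. A small hypothetical helper (not part of the PR) could express the pattern once; it assumes only the shuffle_weight call exactly as it appears in the diff.

import torch
from torch.nn import Module, Parameter

def _shuffle_param(layer: Module, name: str, layout=(16, 16)) -> None:
    # Hypothetical refactor of the repeated pattern above: shuffle the weight
    # for the CK MoE kernels, re-wrap it as a frozen Parameter, and free the
    # old allocation. ROCm-only: requires the aiter package.
    from aiter.ops.shuffle import shuffle_weight

    weight = getattr(layer, name)
    shuffled = shuffle_weight(weight.data, layout)  # same call as in the diff
    setattr(layer, name, Parameter(shuffled, requires_grad=False))
    torch.cuda.empty_cache()

# Usage sketch inside process_weights_hip_int4 / process_weights_hip_scale_padding:
#   _shuffle_param(layer, "w13_weight")
#   _shuffle_param(layer, "w2_weight")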
@@ -867,12 +870,14 @@ def process_weights_hip_scale_padding(self, layer: Module):

if get_bool_env_var("CK_MOE"):
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
# permute_weight(layer.w13_weight.data),
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
permute_weight(layer.w2_weight.data),
# permute_weight(layer.w2_weight.data),
shuffle_weight(layer.w2_weight.data, (16, 16)),
requires_grad=False,
)
torch.cuda.empty_cache()
@@ -928,23 +933,25 @@ def apply(
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return asm_moe(
return ck_moe_2stages_win4(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=activation,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
),
)
if _is_hip and get_bool_env_var("CK_MOE"):
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time being.
assert (
activation == "silu"
), f"CK_MOE: FP8 block_quant {activation=} will be supported later, unset CK_MOE"
return asm_moe(
x,
layer.w13_weight,
@@ -957,14 +964,19 @@
expert_mask=None,
)
else:
return asm_moe(
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
else:
# Expert fusion with FP8 quantization
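
The string-to-ActivationType mapping now appears in both the INT4-FP8 (ck_moe_2stages_win4) and the per-tensor FP8 (ck_moe_2stages) branches of apply. Below is a minimal sketch of factoring it out; the helper is hypothetical, assumes only the ActivationType enum imported from aiter at the top of the file, and, unlike the diff (which maps every non-silu name to Gelu), it raises on activations the CK kernels do not handle.

from aiter import ActivationType  # ROCm-only dependency, already imported at the top of fp8.py

def _to_aiter_activation(activation: str) -> ActivationType:
    # Hypothetical helper for the repeated mapping in apply(). The diff itself
    # treats anything that is not "silu" as GELU; this version is stricter.
    if activation == "silu":
        return ActivationType.Silu
    if activation == "gelu":
        return ActivationType.Gelu
    raise ValueError(f"Unsupported activation for aiter CK MoE kernels: {activation!r}")

# e.g.:
#   ck_moe_2stages(..., activation=_to_aiter_activation(activation))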