sgl-kernel/csrc/cpu/moe.cpp (10 additions, 10 deletions)
@@ -955,16 +955,16 @@ static inline void check_moe_scales(
 }
 }

-#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \
-  auto w1s = w1_scale.value(); \
-  auto w2s = w2_scale.value(); \
-  auto block_size_val = block_size.value(); \
-  int64_t block_size_N = block_size_val[0]; \
-  int64_t block_size_K = block_size_val[1]; \
-  TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \
-  TORCH_CHECK(w1s.size(DIM1) == K / block_size_K); \
-  TORCH_CHECK(w2s.size(DIM0) == K / block_size_N); \
-  TORCH_CHECK(w2s.size(DIM1) == N / block_size_K)
+#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \
+  auto w1s = w1_scale.value(); \
+  auto w2s = w2_scale.value(); \
+  auto block_size_val = block_size.value(); \
+  int64_t block_size_N = block_size_val[0]; \
+  int64_t block_size_K = block_size_val[1]; \
+  TORCH_CHECK(w1s.size(DIM0) == div_up(2 * N, block_size_N)); \
+  TORCH_CHECK(w1s.size(DIM1) == div_up(K, block_size_K)); \
+  TORCH_CHECK(w2s.size(DIM0) == div_up(K, block_size_N)); \
+  TORCH_CHECK(w2s.size(DIM1) == div_up(N, block_size_K))

 // hidden_states: [M, K]
 // w1: [E, 2N, K]
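The change swaps exact integer division for ceiling division, so the scale-shape checks also pass when the block size does not evenly divide the weight dimensions. A minimal sketch of the arithmetic, written in Python for brevity and assuming `div_up` is the usual ceiling-division helper defined elsewhere in the kernel sources (the 128 block width below is only illustrative):

```python
# Hypothetical Python rendering of the ceiling-division helper the macro
# calls; the real div_up is a C++ utility elsewhere in sgl-kernel.
def div_up(a: int, b: int) -> int:
    """Smallest integer >= a / b, for positive integers."""
    return (a + b - 1) // b

# Evenly divisible dims: the old and new checks agree.
assert 512 // 128 == div_up(512, 128) == 4

# Ragged dims (e.g. K = 320 with a 128-wide block): floor division
# undercounts the blocks, so the old check rejected valid scale tensors.
assert 320 // 128 == 2
assert div_up(320, 128) == 3
```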
test/srt/cpu/test_moe.py (10 additions, 4 deletions)
@@ -75,8 +75,8 @@ class TestFusedExperts(CustomTestCase):
     topk_int8 = [3]

     M_fp8 = [2, 121]
-    N_fp8 = [512]
-    K_fp8 = [256]
+    N_fp8 = [352, 512]
+    K_fp8 = [256, 320]
     E_fp8 = [8]
     topk_fp8 = [4]
@@ -201,8 +201,14 @@ def _fp8_moe(self, M, N, K, E, topk):
         w2_fp32 = torch.randn(E, K, N)
         w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

-        w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
-        w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
+        w1s = (
+            torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
+            * factor_for_scale
+        )
+        w2s = (
+            torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
+            * factor_for_scale
+        )

         w1_scaled = scaled_weight(w1, w1s)
         w2_scaled = scaled_weight(w2, w2s)
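The two new ragged sizes exercise the ceiling path end to end. A quick sanity check of the scale shapes the test now builds (BLOCK_N = BLOCK_K = 128 is an assumption for illustration; the real values come from the test module's constants):

```python
import math

# Assumed block sizes; the test module defines the real BLOCK_N / BLOCK_K.
BLOCK_N = BLOCK_K = 128

for N, K in [(512, 256), (352, 320)]:
    w1s_shape = (math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
    w2s_shape = (math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
    print(f"N={N}, K={K}: w1s {w1s_shape}, w2s {w2s_shape}")

# N=512, K=256: w1s (8, 2), w2s (2, 4)  -> divides evenly, old path sufficed
# N=352, K=320: w1s (6, 3), w2s (3, 3)  -> ragged, requires ceil + padding
```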
test/srt/cpu/utils.py (19 additions, 4 deletions)
@@ -136,18 +136,33 @@ def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_facto

 def scaled_weight(weight, scales):
     E, N, K = weight.shape
+    pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
+    pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K
+
+    if pad_N > 0 or pad_K > 0:
+        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
+
     weight_block = (
-        weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K)
+        weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
         .permute(0, 1, 3, 2, 4)
         .float()
         .contiguous()
     )
-    return (
-        (weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1))
+
+    weight_scaled = (
+        (
+            weight_block
+            * scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1)
+        )
         .permute(0, 1, 3, 2, 4)
         .contiguous()
-        .view(E, N, K)
     )
+    if pad_N > 0 or pad_K > 0:
+        weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K)
+        weight_scaled = weight_scaled[..., :N, :K].contiguous()
+    else:
+        weight_scaled = weight_scaled.view(E, N, K)
+    return weight_scaled


 def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
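For reference, a self-contained sketch of the pad-then-slice round trip the rewritten scaled_weight performs, with made-up block sizes and shapes:

```python
import torch

# Made-up sizes for illustration; the real helper uses the module-level
# BLOCK_N / BLOCK_K constants.
BLOCK_N, BLOCK_K = 4, 4
E, N, K = 2, 6, 10  # neither N nor K is a multiple of its block size

pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N  # 2
pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K  # 2

w = torch.randn(E, N, K)
# F.pad consumes the pad tuple from the last dim backwards:
# (left_K, right_K, left_N, right_N), so this zero-pads only the right
# and bottom edges.
w_padded = torch.nn.functional.pad(w, (0, pad_K, 0, pad_N))
assert w_padded.shape == (E, N + pad_N, K + pad_K)  # (2, 8, 12)

# The padded tensor now splits cleanly into whole blocks for per-block scaling.
blocks = w_padded.view(
    E, (N + pad_N) // BLOCK_N, BLOCK_N, (K + pad_K) // BLOCK_K, BLOCK_K
)

# Slicing back to [..., :N, :K] drops the zero padding, mirroring the
# final step of scaled_weight.
assert torch.equal(w_padded[..., :N, :K], w)
```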