
Commit 9659e2e

Enable optimized Qwen3 and Qwen3Moe (#3688)
1 parent 50c7778 commit 9659e2e

File tree

23 files changed: +1238 -354 lines changed


csrc/cpu/aten/DSMoE.cpp

Lines changed: 60 additions & 25 deletions

@@ -150,6 +150,9 @@ constexpr int block_size_m() {
 constexpr int block_size_n() {
   return 8 * TILE_N;
 }
+constexpr int block_size_n2() {
+  return 2 * TILE_N;
+}
 // convert to vnni format
 // from [N, K] to [K/2, N, 2] for bfloat16 and float16
 //
@@ -170,45 +173,76 @@ inline void pack_vnni(
   }
 }

-at::Tensor convert_weight_packed_bf16(at::Tensor& weight) {
+std::tuple<at::Tensor, at::Tensor> convert_weight_packed_moe_bf16(
+    at::Tensor& w1,
+    at::Tensor& w2) {
   // weight : [E, OC, IC]
   // w1 : [E, 2N, K]
   // w2 : [E, K, N]
-  CHECK_DIM(3, weight);
-  const auto st = weight.scalar_type();
-  const int E = weight.size(0);
-  const int OC = weight.size(1);
-  const int IC = weight.size(2);
+  CHECK_DIM(3, w1);
+  CHECK_DIM(3, w2);
+  const auto st1 = w1.scalar_type();
+  const auto st2 = w2.scalar_type();
+  TORCH_CHECK(st1 == st2, "weight type mismatch");
+  const int E1 = w1.size(0);
+  const int E2 = w2.size(0);
+  const int OC1 = w1.size(1);
+  const int IC1 = w1.size(2);
+  const int OC2 = w2.size(1);
+  const int IC2 = w2.size(2);
+  TORCH_CHECK(E1 == E2 && IC2 * 2 == OC1 && IC1 == OC2, "weight size mismatch");
+  const int N = IC2;
+  const int K = IC1;
   // we handle 2 TILE_N at a time.
-  TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
-  TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
-  constexpr int BLOCK_N = block_size_n();
+  TORCH_CHECK(N % TILE_N == 0, "invalid weight out features ", N);
+  TORCH_CHECK(K % TILE_K == 0, "invalid weight input features ", K);
+  int BLOCK_N = block_size_n();
+  if (N < 2 * BLOCK_N) {
+    BLOCK_N = block_size_n2();
+  }
+
   // use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
-  auto packed_weight = at::empty({E, OC, IC}, weight.options());
-  const int stride = OC * IC;
+  auto packed_weight1 = at::empty({E1, OC1, IC1}, w1.options());
+  auto packed_weight2 = at::empty({E2, OC2, IC2}, w2.options());
+  const int stride1 = OC1 * IC1;
+  const int stride2 = OC2 * IC2;
   // TODO: add float8 support
   TORCH_CHECK(
-      st == at::kBFloat16 || st == at::kHalf,
+      st1 == at::kBFloat16 || st1 == at::kHalf,
       "expect weight to be bfloat16 or float16.");
-  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "conver_weight_packed_impl", [&] {
-    const scalar_t* w_data = weight.data_ptr<scalar_t>();
-    scalar_t* packed_data = packed_weight.data_ptr<scalar_t>();
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(st1, "conver_weight_packed_impl", [&] {
+    const scalar_t* w_data1 = w1.data_ptr<scalar_t>();
+    const scalar_t* w_data2 = w2.data_ptr<scalar_t>();
+    scalar_t* packed_data1 = packed_weight1.data_ptr<scalar_t>();
+    scalar_t* packed_data2 = packed_weight2.data_ptr<scalar_t>();
     // parallel on {E}
-    at::parallel_for(0, E, 0, [&](int begin, int end) {
+    at::parallel_for(0, E1, 0, [&](int begin, int end) {
+      for (int e = begin; e < end; ++e) {
+        for (int n = 0; n < OC1; n += BLOCK_N) {
+          int n_size = std::min(BLOCK_N, OC1 - n);
+          pack_vnni<scalar_t>(
+              packed_data1 + e * stride1 + n * IC1,
+              w_data1 + e * stride1 + n * IC1,
+              n_size,
+              IC1);
+        }
+      }
+    });
+    at::parallel_for(0, E2, 0, [&](int begin, int end) {
       for (int e = begin; e < end; ++e) {
-        for (int n = 0; n < OC; n += BLOCK_N) {
-          int n_size = std::min(BLOCK_N, OC - n);
+        for (int n = 0; n < OC2; n += BLOCK_N) {
+          int n_size = std::min(BLOCK_N, OC2 - n);
           pack_vnni<scalar_t>(
-              packed_data + e * stride + n * IC,
-              w_data + e * stride + n * IC,
+              packed_data2 + e * stride2 + n * IC2,
+              w_data2 + e * stride2 + n * IC2,
               n_size,
-              IC);
+              IC2);
         }
       }
     });
   });

-  return packed_weight;
+  return std::make_tuple(packed_weight1, packed_weight2);
 }

 template <typename scalar_t, int SIZE>
@@ -485,10 +519,11 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       "grouped_topk(Tensor hidden_states, Tensor gating_output, \
         int topk, bool renormalize, int num_expert_group, int topk_group, Tensor e_score_correction_bias, Tensor routed_scaling_factor) -> (Tensor, Tensor)");
   m.impl("grouped_topk", c10::DispatchKey::CPU, torch_ipex::cpu::grouped_topk);
-  m.def("convert_weight_packed_bf16(Tensor weight) -> Tensor");
+  m.def(
+      "convert_weight_packed_moe_bf16(Tensor weight1, Tensor weight2) -> (Tensor, Tensor)");
   m.impl(
-      "convert_weight_packed_bf16",
+      "convert_weight_packed_moe_bf16",
       c10::DispatchKey::CPU,
-      torch_ipex::cpu::convert_weight_packed_bf16);
+      torch_ipex::cpu::convert_weight_packed_moe_bf16);
 }
 } // namespace
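For context on what pack_vnni produces: the layout comment above says each [N, K] 16-bit weight slice becomes [K/2, N, 2], pairing consecutive K elements so AMX/VNNI instructions can consume two 16-bit values per lane. Below is a minimal, unoptimized sketch of that index mapping; it is an illustration only, not the tiled pack_vnni used by the kernel. pack_vnni_ref is a hypothetical name, uint16_t stands in for bfloat16/float16 storage, and K is assumed even (guaranteed here by the K % TILE_K check).

#include <cstdint>
#include <vector>

// Reference repack of one row-major [N, K] matrix into [K/2, N, 2]:
// dst[k/2][n][0..1] = src[n][k], src[n][k + 1].
std::vector<uint16_t> pack_vnni_ref(
    const std::vector<uint16_t>& src, int N, int K) {
  std::vector<uint16_t> dst(static_cast<size_t>(N) * K);
  for (int k = 0; k < K; k += 2) {   // walk K in VNNI pairs
    for (int n = 0; n < N; ++n) {
      dst[(k / 2) * (N * 2) + n * 2 + 0] = src[n * K + k];
      dst[(k / 2) * (N * 2) + n * 2 + 1] = src[n * K + k + 1];
    }
  }
  return dst;
}

Applied per expert, this transform covers both inputs of the new op: w1's [2N, K] slices and w2's [K, N] slices, whose shapes must agree exactly as the TORCH_CHECK above (E1 == E2 && IC2 * 2 == OC1 && IC1 == OC2) enforces.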

csrc/cpu/aten/kernels/DSMoEKrnl.cpp

Lines changed: 8 additions & 1 deletion

@@ -479,9 +479,13 @@ void fused_experts_kernel_impl(

   constexpr int BLOCK_M = block_size_m();
   constexpr int T_BLOCK_N = block_size_n(); // Tuned block_n for AMX
+  constexpr int T_BLOCK_N2 = block_size_n2();
   constexpr int Q_BLOCK_N = WOQ_N_BLOCK_SIZE; // Tuned block_n for WOQ
   int BLOCK_N = is_woq ? Q_BLOCK_N : T_BLOCK_N;
   int Q_BLOCK_K = is_woq ? packed_w2_tensor.size(3) : -1;
+  if (!is_woq && N < 2 * BLOCK_N) {
+    BLOCK_N = T_BLOCK_N2;
+  }
   // stage 1: intermediate_cache1 = silu(hidden_states @ w1)
   const int MB = div_up(num_tokens_post_pad, BLOCK_M);
   const int NB = div_up(N, BLOCK_N);
@@ -712,9 +716,12 @@ void fused_experts_kernel_impl(
         if (is_woq) {
           silu_and_mul<scalar_t, Q_BLOCK_N>(
               ic1 + offset * N + nb * BLOCK_N, C0_f, C1_f, m_size, N);
-        } else {
+        } else if (BLOCK_N == T_BLOCK_N) {
           silu_and_mul<scalar_t, T_BLOCK_N>(
               ic1 + offset * N + nb * BLOCK_N, C0_f, C1_f, m_size, N);
+        } else {
+          silu_and_mul<scalar_t, T_BLOCK_N2>(
+              ic1 + offset * N + nb * BLOCK_N, C0_f, C1_f, m_size, N);
         }
         if (use_brgemm) {
           at::native::cpublas::brgemm_release();
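The extra branch is needed because silu_and_mul takes its block width as a compile-time template parameter, so the narrower fallback block requires its own instantiation (T_BLOCK_N2) rather than a runtime argument. Here is a standalone sketch of the selection rule; the concrete TILE_N = 16 below is an assumption for illustration, not taken from this diff, and choose_block_n is a hypothetical helper name.

// With TILE_N = 16 (assumed), the default AMX block is 8 * 16 = 128 columns.
// An expert whose intermediate size N spans fewer than two such blocks
// (N < 256) falls back to 2 * 16 = 32 columns, avoiding mostly-empty tiles.
constexpr int TILE_N = 16; // assumed value
constexpr int block_size_n() { return 8 * TILE_N; }
constexpr int block_size_n2() { return 2 * TILE_N; }

int choose_block_n(bool is_woq, int woq_block_n, int N) {
  int block_n = is_woq ? woq_block_n : block_size_n();
  if (!is_woq && N < 2 * block_n) {
    block_n = block_size_n2(); // narrow-expert fallback
  }
  return block_n;
}

Under these assumed sizes, N = 192 would previously be split as div_up(192, 128) = 2 badly unbalanced blocks; with the fallback it becomes div_up(192, 32) = 6 evenly sized ones.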

csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp

Lines changed: 4 additions & 1 deletion

@@ -1526,7 +1526,10 @@ scale_dot_product_for_indirect_access_kv_cache(
   auto attn_outs_stride_privT = private_attn_outs.stride(0);
   auto attn_outs_stride_privB = private_attn_outs.stride(1);
   auto attn_outs_stride_privH = private_attn_outs.stride(2);
-
+  if (kv_block_size < seq_len) {
+    kv_block_count = max_parallel_parts;
+    kv_block_size = (seq_len + kv_block_count - 1) / kv_block_count;
+  }
   {
     RECORD_FUNCTION(
         "ipex::iakv_sdp::matmul(attn_w, value)",

csrc/cpu/aten/utils/gemm.h

Lines changed: 3 additions & 0 deletions

@@ -13,6 +13,9 @@ constexpr int block_size_m() {
 constexpr int block_size_n() {
   return 8 * TILE_N;
 }
+constexpr int block_size_n2() {
+  return 2 * TILE_N;
+}

 template <typename scalar_t>
 inline void copy_stub(

examples/cpu/llm/inference/distributed/run_accuracy_with_deepspeed.py

Lines changed: 2 additions & 0 deletions

@@ -342,6 +342,8 @@ def __init__(
         "baichuan",
         "baichuan2",
         "gptbigcode",
+        "qwen3moe",
+        "qwen3",
         "qwen",
         "yuan",
         "jamba",
