
Commit 34fab8e

[release/2.4] Port WOQ optimizations for llama3.1 (#3156)
* WOQ: Cache INT8 weight for weight=int4 lowp-mode=INT8 for large batch (#3012)
  * WOQ: Cache INT8 weight for weight=int4 lowp-mode=INT8 for large batch
  * prefetch weight in microkernel
  * Fix UT
  * refine UT
  * Fix format issues
* WOQ: Linear M pad to even; refine block_m (#3051)
  * WOQ: Linear M pad to even; refine block_m
  * Add comment and refine code
* add tpp ACB loop (#3107) (#3114)

---------

Co-authored-by: jianan-gu <[email protected]>
1 parent: c8abce7 · commit: 34fab8e
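
The headline optimization: when an INT4 weight runs with lowp_mode=INT8, re-dequantizing the weight to INT8 on every GEMM call is wasteful once the batch is large, so the INT8 form is computed once and reused (the kernel hook for this, woq_dequant_int4_to_int8_packed_stub, is declared in Linear.h below). A minimal sketch of the caching idea, with a mock dequantization routine and hypothetical names rather than the real blocked-layout kernel:

#include <ATen/ATen.h>
#include <optional>

// Stand-in for the real kernel: unpack two 4-bit values per byte and
// widen to int8. Illustrative only; not the IPEX implementation.
at::Tensor mock_dequant_int4_to_int8(const at::Tensor& packed) {
  auto lo = packed.bitwise_and(0x0F);
  auto hi = packed.bitwise_right_shift(4).bitwise_and(0x0F);
  return at::stack({lo, hi}, /*dim=*/-1).flatten(-2).to(at::kChar);
}

struct CachedWoqWeight {
  at::Tensor int4_packed;                // always-resident packed INT4 data
  std::optional<at::Tensor> int8_cache;  // filled on first use, then reused

  const at::Tensor& int8_weight() {
    if (!int8_cache.has_value()) {
      int8_cache = mock_dequant_int4_to_int8(int4_packed);
    }
    return *int8_cache;
  }
};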

File tree

12 files changed: +784, -358 lines


csrc/cpu/aten/Linear.cpp

Lines changed: 72 additions & 14 deletions
@@ -370,8 +370,11 @@ at::Tensor woq_linear_pack_weight(
     // For TPP kernel, we only consider even K
     if (K % 2 == 0) {
       size_t block_n = 32;
-      size_t default_block_k = (weight_dtype == WOQ_DTYPE_INT4 && lowp_mode == 3) ? 128 : 64;
-      size_t block_k = group_size > 0 ? std::min((size_t)group_size, default_block_k) : default_block_k;
+      size_t default_block_k =
+          (weight_dtype == WOQ_DTYPE_INT4 && lowp_mode == 3) ? 128 : 64;
+      size_t block_k = group_size > 0
+          ? std::min((size_t)group_size, default_block_k)
+          : default_block_k;
       while (K % block_k != 0) {
         block_k /= 2;
       }
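
Restated outside the diff, the new block_k selection starts from a dtype-dependent default (128 for INT4 weight with lowp_mode=INT8, i.e. lowp_mode == 3, otherwise 64), clamps to the quantization group size, then halves until it divides K. A standalone sketch of the same logic:

#include <algorithm>
#include <cstddef>
#include <cstdint>

size_t choose_block_k(size_t K, int64_t group_size, bool int4_int8_compute) {
  // Default block along K: wider for the INT4-weight / INT8-compute path.
  size_t block_k = int4_int8_compute ? 128 : 64;
  if (group_size > 0) {
    // Keep each K block inside a single quantization group.
    block_k = std::min(block_k, (size_t)group_size);
  }
  // Shrink until block_k evenly divides K; K is already even on this path.
  while (K % block_k != 0) {
    block_k /= 2;
  }
  return block_k;
}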
@@ -418,6 +421,7 @@ at::Tensor woq_linear_unpack_weight(
   return woq_tpp_gemm_unpackB_stub(kCPU, weight, weight_dtype, lowp_mode);
 }
 
+#define PAD_M_THRESHOLD 32
 IPEX_DEFINE_DISPATCH(woq_tpp_gemm_kernel_stub);
 at::Tensor woq_linear_kernel(
     const at::Tensor& self,
@@ -428,11 +432,22 @@ at::Tensor woq_linear_kernel(
     const std::vector<at::Tensor>& bias_list,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t act_quant_mode) {
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation) {
   int64_t quant_w_mode = group_size > 0 ? 1 : 0;
-  return woq_tpp_gemm_kernel_stub(
+  auto K = self.size(-1);
+  auto M = self.numel() / K;
+  auto in = self;
+  // Big performance regression is found with some odd M
+  // So, we pad M to the nearest even number
+  bool m_padded = false;
+  if (M >= PAD_M_THRESHOLD && M % 2 == 1) {
+    in = at::pad(in.view({M, K}), {0, 0, 0, 1}, "constant", 0);
+    m_padded = true;
+  }
+  auto y = woq_tpp_gemm_kernel_stub(
       kCPU,
-      self,
+      in,
       weight,
       scales_list,
       zps_list,
@@ -443,7 +458,14 @@ at::Tensor woq_linear_kernel(
       std::vector<at::Tensor>(),
       act_quant_mode,
       quant_w_mode,
-      group_size);
+      group_size,
+      compensation);
+  if (m_padded) {
+    auto out_size = self.sizes().vec();
+    out_size.back() = y.size(-1);
+    y = y.narrow(0, 0, M).view(out_size);
+  }
+  return y;
 }
 
 at::Tensor woq_linear_forward(
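
The same pad-then-slice pattern is repeated in the unary and binary kernels below. A self-contained round trip of that logic, using at::matmul as a stand-in for woq_tpp_gemm_kernel_stub and the PAD_M_THRESHOLD of 32 from this commit:

#include <ATen/ATen.h>

at::Tensor padded_gemm(const at::Tensor& self, const at::Tensor& weight) {
  auto K = self.size(-1);
  auto M = self.numel() / K;
  auto in = self.view({M, K});
  bool m_padded = false;
  if (M >= 32 && M % 2 == 1) {
    // {pad_left, pad_right, pad_top, pad_bottom}: one zero row at the bottom.
    in = at::pad(in, {0, 0, 0, 1}, "constant", 0);
    m_padded = true;
  }
  auto y = at::matmul(in, weight); // stand-in for the WOQ TPP kernel
  if (m_padded) {
    y = y.narrow(0, 0, M); // drop the zero row added above
  }
  auto out_size = self.sizes().vec();
  out_size.back() = y.size(-1);
  return y.view(out_size); // restore the original leading dimensions
}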
@@ -466,7 +488,8 @@ at::Tensor woq_linear_unary_kernel(
     const c10::optional<c10::string_view>& algorithm,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t act_quant_mode) {
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation) {
   int64_t post_op_fusion_type = WOQ_FUSE_NONE;
   if (post_op == "gelu") {
     if (algorithm == "none") {
@@ -480,9 +503,19 @@ at::Tensor woq_linear_unary_kernel(
     post_op_fusion_type = WOQ_FUSE_SILU;
   }
   int64_t quant_w_mode = group_size > 0 ? 1 : 0;
-  return woq_tpp_gemm_kernel_stub(
+  auto K = self.size(-1);
+  auto M = self.numel() / K;
+  auto in = self;
+  // Big performance regression is found with some odd M
+  // So, we pad M to the nearest even number
+  bool m_padded = false;
+  if (M >= PAD_M_THRESHOLD && M % 2 == 1) {
+    in = at::pad(in.view({M, K}), {0, 0, 0, 1}, "constant", 0);
+    m_padded = true;
+  }
+  auto y = woq_tpp_gemm_kernel_stub(
       kCPU,
-      self,
+      in,
       weight,
       scales_list,
       zps_list,
@@ -493,7 +526,14 @@ at::Tensor woq_linear_unary_kernel(
       std::vector<at::Tensor>(),
       act_quant_mode,
       quant_w_mode,
-      group_size);
+      group_size,
+      compensation);
+  if (m_padded) {
+    auto out_size = self.sizes().vec();
+    out_size.back() = y.size(-1);
+    y = y.narrow(0, 0, M).view(out_size);
+  }
+  return y;
 }
 
 at::Tensor woq_linear_gelu_forward(
@@ -541,7 +581,8 @@ at::Tensor woq_linear_binary_kernel(
     int64_t lowp_mode,
     const c10::string_view& post_op,
     const std::vector<at::Tensor>& others,
-    int64_t act_quant_mode) {
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation) {
   int64_t post_op_fusion_type = WOQ_FUSE_NONE;
   if (post_op == "add") {
     post_op_fusion_type = WOQ_FUSE_ADD;
@@ -551,9 +592,19 @@ at::Tensor woq_linear_binary_kernel(
     post_op_fusion_type = WOQ_FUSE_MUL;
   }
   int64_t quant_w_mode = group_size > 0 ? 1 : 0;
-  return woq_tpp_gemm_kernel_stub(
+  auto K = self.size(-1);
+  auto M = self.numel() / K;
+  auto in = self;
+  // Big performance regression is found with some odd M
+  // So, we pad M to the nearest even number
+  bool m_padded = false;
+  if (M >= PAD_M_THRESHOLD && M % 2 == 1) {
+    in = at::pad(in.view({M, K}), {0, 0, 0, 1}, "constant", 0);
+    m_padded = true;
+  }
+  auto y = woq_tpp_gemm_kernel_stub(
       kCPU,
-      self,
+      in,
       weight,
       scales_list,
       zps_list,
@@ -564,7 +615,14 @@ at::Tensor woq_linear_binary_kernel(
       others,
       act_quant_mode,
       quant_w_mode,
-      group_size);
+      group_size,
+      compensation);
+  if (m_padded) {
+    auto out_size = self.sizes().vec();
+    out_size.back() = y.size(-1);
+    y = y.narrow(0, 0, M).view(out_size);
+  }
+  return y;
 }
 
 at::Tensor woq_linear_add_forward(

csrc/cpu/aten/Linear.h

Lines changed: 19 additions & 4 deletions
@@ -136,7 +136,8 @@ at::Tensor woq_linear_kernel(
     const std::vector<at::Tensor>& bias_list,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t act_quant_mode);
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation = c10::nullopt);
 
 at::Tensor woq_linear_unary_kernel(
     const at::Tensor& self,
@@ -150,7 +151,8 @@ at::Tensor woq_linear_unary_kernel(
     const c10::optional<c10::string_view>& algorithm,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t act_quant_mode);
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation = c10::nullopt);
 
 at::Tensor woq_linear_binary_kernel(
     const at::Tensor& self,
@@ -163,7 +165,8 @@ at::Tensor woq_linear_binary_kernel(
     int64_t lowp_mode,
     const c10::string_view& post_op,
     const std::vector<at::Tensor>& others,
-    int64_t act_quant_mode);
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation = c10::nullopt);
 
 namespace {
 void woq_gemm_kernel_impl(
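
Because compensation defaults to c10::nullopt in all three declarations, existing call sites compile unchanged. A usage sketch (assuming tensors of the appropriate shapes are already in scope):

// Old-style call: compensation silently defaults to c10::nullopt.
auto y0 = woq_linear_kernel(
    x, qweight, scales_list, zps_list, bias_list,
    group_size, lowp_mode, act_quant_mode);
// Opt-in call: pass a precomputed compensation tensor for the INT8 path.
auto y1 = woq_linear_kernel(
    x, qweight, scales_list, zps_list, bias_list,
    group_size, lowp_mode, act_quant_mode, compensation);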
@@ -239,16 +242,28 @@ using woq_tpp_gemm_kernel_fn = at::Tensor (*)(
     const std::vector<at::Tensor>&,
     int64_t,
     int64_t,
-    int64_t);
+    int64_t,
+    const c10::optional<at::Tensor>&);
 
 using woq_tpp_gemm_packB_fn =
     at::Tensor (*)(const at::Tensor&, int, size_t, size_t, int64_t);
 
 using woq_tpp_gemm_unpackB_fn = at::Tensor (*)(const at::Tensor&, int, int64_t);
 
+using woq_dequant_int4_to_int8_packed_fn = at::Tensor (*)(
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    int,
+    int64_t,
+    at::Tensor&);
+
 IPEX_DECLARE_DISPATCH(woq_tpp_gemm_kernel_fn, woq_tpp_gemm_kernel_stub);
 IPEX_DECLARE_DISPATCH(woq_tpp_gemm_packB_fn, woq_tpp_gemm_packB_stub);
 IPEX_DECLARE_DISPATCH(woq_tpp_gemm_unpackB_fn, woq_tpp_gemm_unpackB_stub);
+IPEX_DECLARE_DISPATCH(
+    woq_dequant_int4_to_int8_packed_fn,
+    woq_dequant_int4_to_int8_packed_stub);
 
 // Fusion types
 #define WOQ_FUSE_NONE 0x0
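
IPEX_DECLARE_DISPATCH / IPEX_DEFINE_DISPATCH follow ATen's DispatchStub pattern: a named, typed function-pointer slot that per-ISA kernel files fill in at registration time. Roughly, and only as an illustration of the mechanism:

#include <ATen/ATen.h>

// Simplified model of a dispatch stub (the real macros also select among
// per-ISA kernel builds at runtime based on CPU capability).
using woq_dequant_fn = at::Tensor (*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&,
    int, int64_t, at::Tensor&);

struct WoqDequantStub {
  woq_dequant_fn fn = nullptr; // assigned by the kernel's registration code
  at::Tensor operator()(
      const at::Tensor& qw, const at::Tensor& scales, const at::Tensor& zps,
      int weight_dtype, int64_t lowp_mode, at::Tensor& out) const {
    return fn(qw, scales, zps, weight_dtype, lowp_mode, out);
  }
};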
