@@ -370,8 +370,11 @@ at::Tensor woq_linear_pack_weight(
   // For TPP kernel, we only consider even K
   if (K % 2 == 0) {
     size_t block_n = 32;
-    size_t default_block_k = (weight_dtype == WOQ_DTYPE_INT4 && lowp_mode == 3) ? 128 : 64;
-    size_t block_k = group_size > 0 ? std::min((size_t)group_size, default_block_k) : default_block_k;
+    size_t default_block_k =
+        (weight_dtype == WOQ_DTYPE_INT4 && lowp_mode == 3) ? 128 : 64;
+    size_t block_k = group_size > 0
+        ? std::min((size_t)group_size, default_block_k)
+        : default_block_k;
     while (K % block_k != 0) {
       block_k /= 2;
     }
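The rewrapped lines above are formatting only; the block-size selection itself is unchanged. As a rough standalone sketch of that selection (the WOQ_DTYPE_INT4 value and the helper name below are stand-ins for illustration, not the repository's definitions): start from 128 for INT4 weights with lowp_mode 3, otherwise 64, cap it at group_size when grouping is used, then halve until it divides K.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Stand-in constant for illustration only; the real value comes from the
// WOQ headers in the repository.
constexpr int WOQ_DTYPE_INT4 = 2;  // assumed value

// Hypothetical helper mirroring the block_k selection shown in the hunk above.
size_t pick_block_k(size_t K, int64_t group_size, int weight_dtype, int64_t lowp_mode) {
  size_t default_block_k =
      (weight_dtype == WOQ_DTYPE_INT4 && lowp_mode == 3) ? 128 : 64;
  size_t block_k = group_size > 0
      ? std::min((size_t)group_size, default_block_k)
      : default_block_k;
  // K is known to be even at this point (see the enclosing if), so this
  // terminates with a block_k that divides K.
  while (K % block_k != 0) {
    block_k /= 2;
  }
  return block_k;
}

int main() {
  std::printf("%zu\n", pick_block_k(4096, 32, WOQ_DTYPE_INT4, 3));  // 32: capped by group_size
  std::printf("%zu\n", pick_block_k(96, -1, WOQ_DTYPE_INT4, 2));    // 32: 64 halved once to divide 96
  return 0;
}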
@@ -418,6 +421,7 @@ at::Tensor woq_linear_unpack_weight(
   return woq_tpp_gemm_unpackB_stub(kCPU, weight, weight_dtype, lowp_mode);
 }
 
+#define PAD_M_THRESHOLD 32
 IPEX_DEFINE_DISPATCH(woq_tpp_gemm_kernel_stub);
 at::Tensor woq_linear_kernel(
     const at::Tensor& self,
@@ -428,11 +432,22 @@ at::Tensor woq_linear_kernel(
     const std::vector<at::Tensor>& bias_list,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t act_quant_mode) {
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation) {
   int64_t quant_w_mode = group_size > 0 ? 1 : 0;
-  return woq_tpp_gemm_kernel_stub(
+  auto K = self.size(-1);
+  auto M = self.numel() / K;
+  auto in = self;
+  // Big performance regression is found with some odd M
+  // So, we pad M to the nearest even number
+  bool m_padded = false;
+  if (M >= PAD_M_THRESHOLD && M % 2 == 1) {
+    in = at::pad(in.view({M, K}), {0, 0, 0, 1}, "constant", 0);
+    m_padded = true;
+  }
+  auto y = woq_tpp_gemm_kernel_stub(
       kCPU,
-      self,
+      in,
       weight,
       scales_list,
       zps_list,
@@ -443,7 +458,14 @@ at::Tensor woq_linear_kernel(
       std::vector<at::Tensor>(),
       act_quant_mode,
       quant_w_mode,
-      group_size);
+      group_size,
+      compensation);
+  if (m_padded) {
+    auto out_size = self.sizes().vec();
+    out_size.back() = y.size(-1);
+    y = y.narrow(0, 0, M).view(out_size);
+  }
+  return y;
 }
 
 at::Tensor woq_linear_forward(
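The new path above pads an activation with an odd row count M (once M reaches PAD_M_THRESHOLD) to the next even M before calling the TPP GEMM stub, then narrows the result back and restores the caller's shape. A minimal standalone sketch of that shape handling, with at::matmul standing in for woq_tpp_gemm_kernel_stub (which takes many more arguments than shown here) and kPadMThreshold mirroring the macro introduced in this commit:

#include <ATen/ATen.h>

constexpr int64_t kPadMThreshold = 32;  // mirrors PAD_M_THRESHOLD above

// Sketch only: pad an odd M to even, run a stand-in GEMM, then drop the
// padding row and restore the original leading dimensions.
at::Tensor padded_linear_sketch(const at::Tensor& self, const at::Tensor& weight_t) {
  auto K = self.size(-1);
  auto M = self.numel() / K;
  auto in = self;
  bool m_padded = false;
  if (M >= kPadMThreshold && M % 2 == 1) {
    // Append one zero row so the kernel always sees an even M.
    in = at::pad(in.view({M, K}), {0, 0, 0, 1}, "constant", 0);
    m_padded = true;
  }
  auto y = at::matmul(in, weight_t);  // stand-in for the quantized GEMM
  if (m_padded) {
    auto out_size = self.sizes().vec();
    out_size.back() = y.size(-1);
    y = y.narrow(0, 0, M).view(out_size);
  }
  return y;
}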
@@ -466,7 +488,8 @@ at::Tensor woq_linear_unary_kernel(
     const c10::optional<c10::string_view>& algorithm,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t act_quant_mode) {
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation) {
   int64_t post_op_fusion_type = WOQ_FUSE_NONE;
   if (post_op == "gelu") {
     if (algorithm == "none") {
@@ -480,9 +503,19 @@ at::Tensor woq_linear_unary_kernel(
     post_op_fusion_type = WOQ_FUSE_SILU;
   }
   int64_t quant_w_mode = group_size > 0 ? 1 : 0;
-  return woq_tpp_gemm_kernel_stub(
+  auto K = self.size(-1);
+  auto M = self.numel() / K;
+  auto in = self;
+  // Big performance regression is found with some odd M
+  // So, we pad M to the nearest even number
+  bool m_padded = false;
+  if (M >= PAD_M_THRESHOLD && M % 2 == 1) {
+    in = at::pad(in.view({M, K}), {0, 0, 0, 1}, "constant", 0);
+    m_padded = true;
+  }
+  auto y = woq_tpp_gemm_kernel_stub(
       kCPU,
-      self,
+      in,
       weight,
       scales_list,
       zps_list,
@@ -493,7 +526,14 @@ at::Tensor woq_linear_unary_kernel(
       std::vector<at::Tensor>(),
       act_quant_mode,
       quant_w_mode,
-      group_size);
+      group_size,
+      compensation);
+  if (m_padded) {
+    auto out_size = self.sizes().vec();
+    out_size.back() = y.size(-1);
+    y = y.narrow(0, 0, M).view(out_size);
+  }
+  return y;
 }
 
 at::Tensor woq_linear_gelu_forward(
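Note that the pad/unpad sequence added here is identical to the one added to woq_linear_kernel above, and it appears a third time in woq_linear_binary_kernel below. The commit keeps it inlined in each kernel; a small helper such as the following (hypothetical, not part of this commit) could factor the duplication out:

#include <ATen/ATen.h>

// Hypothetical refactoring sketch, not part of the commit.
struct MaybePadded {
  at::Tensor in;  // possibly padded 2-D activation
  int64_t m;      // row count before padding
  bool padded;    // whether a zero row was appended
};

inline MaybePadded maybe_pad_odd_m(const at::Tensor& self, int64_t threshold) {
  auto K = self.size(-1);
  auto M = self.numel() / K;
  if (M >= threshold && M % 2 == 1) {
    return {at::pad(self.view({M, K}), {0, 0, 0, 1}, "constant", 0), M, true};
  }
  return {self, M, false};
}

inline at::Tensor maybe_unpad_m(at::Tensor y, const MaybePadded& p, const at::Tensor& self) {
  if (!p.padded) {
    return y;
  }
  auto out_size = self.sizes().vec();
  out_size.back() = y.size(-1);
  return y.narrow(0, 0, p.m).view(out_size);
}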
@@ -541,7 +581,8 @@ at::Tensor woq_linear_binary_kernel(
     int64_t lowp_mode,
     const c10::string_view& post_op,
     const std::vector<at::Tensor>& others,
-    int64_t act_quant_mode) {
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation) {
   int64_t post_op_fusion_type = WOQ_FUSE_NONE;
   if (post_op == "add") {
     post_op_fusion_type = WOQ_FUSE_ADD;
@@ -551,9 +592,19 @@ at::Tensor woq_linear_binary_kernel(
     post_op_fusion_type = WOQ_FUSE_MUL;
   }
   int64_t quant_w_mode = group_size > 0 ? 1 : 0;
-  return woq_tpp_gemm_kernel_stub(
+  auto K = self.size(-1);
+  auto M = self.numel() / K;
+  auto in = self;
+  // Big performance regression is found with some odd M
+  // So, we pad M to the nearest even number
+  bool m_padded = false;
+  if (M >= PAD_M_THRESHOLD && M % 2 == 1) {
+    in = at::pad(in.view({M, K}), {0, 0, 0, 1}, "constant", 0);
+    m_padded = true;
+  }
+  auto y = woq_tpp_gemm_kernel_stub(
       kCPU,
-      self,
+      in,
       weight,
       scales_list,
       zps_list,
@@ -564,7 +615,14 @@ at::Tensor woq_linear_binary_kernel(
       others,
       act_quant_mode,
       quant_w_mode,
-      group_size);
+      group_size,
+      compensation);
+  if (m_padded) {
+    auto out_size = self.sizes().vec();
+    out_size.back() = y.size(-1);
+    y = y.narrow(0, 0, M).view(out_size);
+  }
+  return y;
 }
 
 at::Tensor woq_linear_add_forward(