@@ -910,6 +910,13 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
910
910
((ggml_tensor*)dst->extra )->ne );
911
911
return ;
912
912
}
913
+ if (dst->type == GGML_TYPE_Q4_0) {
914
+ aclrtlaunch_ascendc_quantize_f16_to_q4_0 (
915
+ 24 , ctx.stream (), src->data , dst->data ,
916
+ ((ggml_tensor*)src->extra )->ne , ((ggml_tensor*)src->extra )->nb ,
917
+ ((ggml_tensor*)dst->extra )->ne );
918
+ return ;
919
+ }
913
920
if (dst->type == GGML_TYPE_F16) {
914
921
if (ggml_are_same_shape (src, dst)) {
915
922
cann_copy (ctx, acl_src, acl_dst);
@@ -971,6 +978,13 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
971
978
((ggml_tensor*)dst->extra )->ne );
972
979
return ;
973
980
}
981
+ if (dst->type == GGML_TYPE_Q4_0) {
982
+ aclrtlaunch_ascendc_quantize_f32_to_q4_0 (
983
+ 24 , ctx.stream (), src->data , dst->data ,
984
+ ((ggml_tensor*)src->extra )->ne , ((ggml_tensor*)src->extra )->nb ,
985
+ ((ggml_tensor*)dst->extra )->ne );
986
+ return ;
987
+ }
974
988
if (dst->type == GGML_TYPE_F32) {
975
989
if (ggml_are_same_shape (src, dst)) {
976
990
cann_copy (ctx, acl_src, acl_dst);
@@ -2463,21 +2477,33 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
2463
2477
* @param dst The destination tensor where the result of the matrix
2464
2478
* multiplication will be stored.
2465
2479
*/
2466
- static void ggml_cann_mul_mat_q8_0 (ggml_backend_cann_context& ctx,
2467
- ggml_tensor* dst) {
2480
+ static void ggml_cann_mul_mat_quant (ggml_backend_cann_context& ctx,
2481
+ ggml_tensor* dst,
2482
+ const enum ggml_type type) {
2468
2483
ggml_tensor* src0 = dst->src [0 ]; // weight
2469
2484
ggml_tensor* src1 = dst->src [1 ]; // input
2470
2485
2471
2486
// The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
2472
2487
// is regarded as batch. weight need transpose.
2473
2488
int64_t weight_ne[] = {src0->ne [1 ], src0->ne [0 ]};
2474
- size_t weight_elem_size = sizeof (uint8_t );
2475
- size_t weight_nb[] = {weight_elem_size * src0->ne [0 ], weight_elem_size};
2489
+ float weight_elem_size;
2490
+ if (type == GGML_TYPE_Q4_0) {
2491
+ weight_elem_size = float (sizeof (uint8_t )) / 2 ;
2492
+ }
2493
+ else if (type == GGML_TYPE_Q8_0) {
2494
+ weight_elem_size = float (sizeof (uint8_t ));
2495
+ }
2496
+ else {
2497
+ GGML_ABORT (" Only support Q4_0 and Q8_0 MUL_MAT" );
2498
+ }
2499
+ float weight_nb[] = {weight_elem_size * src0->ne [0 ], weight_elem_size};
2500
+
2476
2501
// size of one matrix is element_size * height * width.
2477
2502
size_t weight_stride = weight_elem_size * src0->ne [0 ] * src0->ne [1 ];
2478
2503
size_t weight_size = weight_stride * src0->ne [2 ] * src0->ne [3 ];
2479
2504
2480
2505
// scale stored at the end of weight. Also need transpose.
2506
+ GGML_ASSERT (QK4_0 == QK8_0);
2481
2507
int64_t scale_ne[] = {src0->ne [1 ], src0->ne [0 ] / QK8_0};
2482
2508
size_t scale_elem_size = sizeof (uint16_t );
2483
2509
size_t scale_nb[] = {src0->ne [0 ] / QK8_0 * scale_elem_size,
@@ -2541,8 +2567,9 @@ static void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx,
2541
2567
(char *)input_buffer + batch1 * input_stride, ACL_FLOAT16,
2542
2568
input_elem_size, input_ne, input_nb, 2 );
2543
2569
aclTensor* acl_weight_tensor = ggml_cann_create_tensor (
2544
- (char *)src0->data + batch0 * weight_stride, ACL_INT8,
2545
- weight_elem_size, weight_ne, weight_nb, 2 );
2570
+ (char *)src0->data + batch0 * weight_stride,
2571
+ ggml_cann_type_mapping (type), weight_elem_size, weight_ne,
2572
+ weight_nb, 2 );
2546
2573
aclTensor* acl_scale_tensor = ggml_cann_create_tensor (
2547
2574
scale_offset + batch0 * scale_stride, ACL_FLOAT16,
2548
2575
scale_elem_size, scale_ne, scale_nb, 2 );
@@ -2596,11 +2623,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2596
2623
case GGML_TYPE_F16:
2597
2624
ggml_cann_mat_mul_fp (ctx, dst);
2598
2625
break ;
2599
- // case GGML_TYPE_Q4_0:
2600
- // ggml_cann_mul_mat_q4_0(ctx, dst);
2601
- // break;
2626
+ case GGML_TYPE_Q4_0:
2602
2627
case GGML_TYPE_Q8_0:
2603
- ggml_cann_mul_mat_q8_0 (ctx, dst);
2628
+ ggml_cann_mul_mat_quant (ctx, dst, type );
2604
2629
break ;
2605
2630
default :
2606
2631
GGML_ABORT (" fatal error" );
0 commit comments