Skip to content

Commit 6a2c614

Browse files
authored
metal : extend mat-mat multiplication support (#16225)
* metal : support mul_mm with src1->type == GGML_TYPE_F16
* metal : support mul_mm_id with src1->type == GGML_TYPE_F16 [no ci]
* metal : mul_mm support ne00 % 32 != 0
* metal : support mul_mm_id with ne00 % 32 != 0
* cont : remove unnecessary unrolls
* cont : simplify data loading
* metal : optimize mul_mm when output bounds checks are not needed
1 parent 3b53634 commit 6a2c614

File tree

7 files changed

+248
-121
lines changed

7 files changed

+248
-121
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -438,21 +438,35 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_libr
438438
return res;
439439
}
440440

441-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
441+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
442442
char base[256];
443443
char name[256];
444444

445+
const ggml_type tsrc0 = op->src[0]->type;
446+
const ggml_type tsrc1 = op->src[1]->type;
447+
448+
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
449+
const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
450+
445451
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
446-
snprintf(name, 256, "%s", base);
452+
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
447453

448454
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
449455
if (res) {
450456
return res;
451457
}
452458

453-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
459+
ggml_metal_cv_t cv = ggml_metal_cv_init();
454460

455-
ggml_metal_pipeline_set_smem(res, 8192);
461+
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
462+
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
463+
464+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
465+
466+
ggml_metal_cv_free(cv);
467+
468+
// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
469+
ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);
456470

457471
return res;
458472
}
@@ -659,19 +673,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
659673
return res;
660674
}
661675

662-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
676+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
663677
char base[256];
664678
char name[256];
665679

680+
const ggml_type tsrc0 = op->src[0]->type;
681+
const ggml_type tsrc1 = op->src[1]->type;
682+
683+
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
684+
666685
snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
667-
snprintf(name, 256, "%s", base);
686+
snprintf(name, 256, "%s_bci=%d", base, bc_inp);
668687

669688
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
670689
if (res) {
671690
return res;
672691
}
673692

674-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
693+
ggml_metal_cv_t cv = ggml_metal_cv_init();
694+
695+
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
696+
697+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
698+
699+
ggml_metal_cv_free(cv);
675700

676701
ggml_metal_pipeline_set_smem(res, 8192);
677702

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_me
115115
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
116116
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
117117
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
118-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
118+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
119119
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
120120
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
121-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
121+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
122122
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
123123
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
124124
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -717,8 +717,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
717717
return true;
718718
case GGML_OP_MUL_MAT:
719719
case GGML_OP_MUL_MAT_ID:
720-
return has_simdgroup_reduction &&
721-
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
720+
return has_simdgroup_reduction;
722721
case GGML_OP_CPY:
723722
case GGML_OP_DUP:
724723
case GGML_OP_CONT:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
#define FC_FLASH_ATTN_EXT_VEC 200
7777
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
7878
#define FC_MUL_MV 400
79+
#define FC_MUL_MM 500
7980

8081
// kernel argument structs
8182
//

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,22 +1520,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
15201520
!ggml_is_transposed(op->src[1]) &&
15211521
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
15221522
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1523-
props_dev->has_simdgroup_mm &&
1524-
op->src[1]->type == GGML_TYPE_F32 &&
1525-
ne00 % 32 == 0 && ne00 >= 64 &&
1523+
props_dev->has_simdgroup_mm && ne00 >= 64 &&
15261524
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
15271525
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
15281526

15291527
// some Metal matrix data types require aligned pointers
15301528
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1531-
switch (op->src[0]->type) {
1532-
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1533-
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1534-
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1535-
default: break;
1536-
}
1529+
//switch (op->src[0]->type) {
1530+
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1531+
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1532+
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1533+
// default: break;
1534+
//}
15371535

1538-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op->src[0]->type, op->src[1]->type);
1536+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
15391537

15401538
ggml_metal_kargs_mul_mm args = {
15411539
/*.ne00 =*/ ne00,
@@ -1655,8 +1653,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
16551653
GGML_ASSERT(!ggml_is_transposed(op->src[0]));
16561654
GGML_ASSERT(!ggml_is_transposed(op->src[1]));
16571655

1658-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1659-
16601656
GGML_ASSERT(ne03 == 1);
16611657
GGML_ASSERT(ne13 == 1);
16621658

@@ -1674,19 +1670,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
16741670
// ne21 = n_rows (batch size)
16751671
const int ne21_mm_id_min = 32;
16761672

1677-
if (props_dev->has_simdgroup_mm &&
1678-
ne00 % 32 == 0 && ne00 >= 64 &&
1679-
(ne21 >= ne21_mm_id_min)) {
1680-
GGML_ASSERT(ne00 % 4 == 0);
1681-
1673+
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
16821674
// some Metal matrix data types require aligned pointers
16831675
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1684-
switch (op->src[0]->type) {
1685-
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1686-
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1687-
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1688-
default: break;
1689-
}
1676+
//switch (op->src[0]->type) {
1677+
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1678+
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1679+
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1680+
// default: break;
1681+
//}
16901682

16911683
// extra buffers for intermediate id mapping
16921684
ggml_metal_buffer_id bid_tpe = bid_dst;
@@ -1730,7 +1722,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
17301722
ggml_metal_op_concurrency_reset(ctx);
17311723

17321724
{
1733-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op->src[0]->type, GGML_TYPE_F16);
1725+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
17341726

17351727
ggml_metal_kargs_mul_mm_id args = {
17361728
/*.ne00 =*/ ne00,

0 commit comments

Comments
 (0)