@@ -1520,22 +1520,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1520
1520
!ggml_is_transposed (op->src [1 ]) &&
1521
1521
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1522
1522
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1523
- props_dev->has_simdgroup_mm &&
1524
- op->src [1 ]->type == GGML_TYPE_F32 &&
1525
- ne00 % 32 == 0 && ne00 >= 64 &&
1523
+ props_dev->has_simdgroup_mm && ne00 >= 64 &&
1526
1524
(ne11 > ne11_mm_min || (ggml_is_quantized (op->src [0 ]->type ) && ne12 > 1 ))) {
1527
1525
// printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1528
1526
1529
1527
// some Metal matrix data types require aligned pointers
1530
1528
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1531
- switch (op->src [0 ]->type ) {
1532
- case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1533
- case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1534
- case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1535
- default : break ;
1536
- }
1529
+ // switch (op->src[0]->type) {
1530
+ // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1531
+ // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1532
+ // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1533
+ // default: break;
1534
+ // }
1537
1535
1538
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm (lib, op-> src [ 0 ]-> type , op-> src [ 1 ]-> type );
1536
+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm (lib, op);
1539
1537
1540
1538
ggml_metal_kargs_mul_mm args = {
1541
1539
/* .ne00 =*/ ne00,
@@ -1655,8 +1653,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1655
1653
GGML_ASSERT (!ggml_is_transposed (op->src [0 ]));
1656
1654
GGML_ASSERT (!ggml_is_transposed (op->src [1 ]));
1657
1655
1658
- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
1659
-
1660
1656
GGML_ASSERT (ne03 == 1 );
1661
1657
GGML_ASSERT (ne13 == 1 );
1662
1658
@@ -1674,19 +1670,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1674
1670
// ne21 = n_rows (batch size)
1675
1671
const int ne21_mm_id_min = 32 ;
1676
1672
1677
- if (props_dev->has_simdgroup_mm &&
1678
- ne00 % 32 == 0 && ne00 >= 64 &&
1679
- (ne21 >= ne21_mm_id_min)) {
1680
- GGML_ASSERT (ne00 % 4 == 0 );
1681
-
1673
+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
1682
1674
// some Metal matrix data types require aligned pointers
1683
1675
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1684
- switch (op->src [0 ]->type ) {
1685
- case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1686
- case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1687
- case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1688
- default : break ;
1689
- }
1676
+ // switch (op->src[0]->type) {
1677
+ // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1678
+ // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1679
+ // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1680
+ // default: break;
1681
+ // }
1690
1682
1691
1683
// extra buffers for intermediate id mapping
1692
1684
ggml_metal_buffer_id bid_tpe = bid_dst;
@@ -1730,7 +1722,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1730
1722
ggml_metal_op_concurrency_reset (ctx);
1731
1723
1732
1724
{
1733
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id (lib, op-> src [ 0 ]-> type , GGML_TYPE_F16 );
1725
+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id (lib, op);
1734
1726
1735
1727
ggml_metal_kargs_mul_mm_id args = {
1736
1728
/* .ne00 =*/ ne00,
0 commit comments