Skip to content

Commit 6492522

Browse files
ggerganov authored and struct committed
metal : fuse NORM + MUL + ADD, support non-multiples of 4 (ggml-org#16220)
* metal : fuse NORM + MUL + ADD
* metal : support norms of non-multiple of 4
* cont : fix comment [no ci]
1 parent 45898a9 commit 6492522

File tree

9 files changed

+206
-232
lines changed

9 files changed

+206
-232
lines changed

ggml/src/ggml-metal/ggml-metal-common.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -383,6 +383,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
383383
// fuse only ops that start with these operations
384384
// can be expanded when needed
385385
if (node.op() == GGML_OP_ADD ||
386+
node.op() == GGML_OP_NORM ||
386387
node.op() == GGML_OP_RMS_NORM) {
387388
ops[0] = node.op();
388389

@@ -392,6 +393,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
392393
// can be expanded when needed
393394
if (gf->nodes[f]->op != GGML_OP_ADD &&
394395
gf->nodes[f]->op != GGML_OP_MUL &&
396+
gf->nodes[f]->op != GGML_OP_NORM &&
395397
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
396398
break;
397399
}

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 26 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1090,36 +1090,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
10901090
return res;
10911091
}
10921092

1093-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
1094-
assert(op->op == GGML_OP_RMS_NORM);
1095-
1096-
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
1097-
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
1098-
1099-
char base[256];
1100-
char name[256];
1101-
1102-
switch (n_fuse) {
1103-
case 1: snprintf(base, 256, "kernel_rms_norm_f32"); break;
1104-
case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32"); break;
1105-
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break;
1106-
default: GGML_ABORT("fatal error");
1107-
}
1108-
1109-
snprintf(name, 256, "%s", base);
1110-
1111-
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1112-
if (res) {
1113-
return res;
1114-
}
1115-
1116-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1117-
1118-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
1119-
1120-
return res;
1121-
}
1122-
11231093
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
11241094
assert(op->op == GGML_OP_L2_NORM);
11251095

@@ -1167,16 +1137,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
11671137
return res;
11681138
}
11691139

1170-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1171-
assert(op->op == GGML_OP_NORM);
1140+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
1141+
assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
11721142

1173-
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
1174-
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
1143+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
11751144

11761145
char base[256];
11771146
char name[256];
11781147

1179-
snprintf(base, 256, "kernel_norm_f32");
1148+
const char * suffix = "";
1149+
if (op->ne[0] % 4 == 0) {
1150+
suffix = "_4";
1151+
}
1152+
1153+
switch (op->op) {
1154+
case GGML_OP_NORM:
1155+
switch (n_fuse) {
1156+
case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break;
1157+
case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break;
1158+
case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
1159+
default: GGML_ABORT("fatal error");
1160+
} break;
1161+
case GGML_OP_RMS_NORM:
1162+
switch (n_fuse) {
1163+
case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break;
1164+
case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break;
1165+
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
1166+
default: GGML_ABORT("fatal error");
1167+
} break;
1168+
default: GGML_ABORT("fatal error");
1169+
}
1170+
11801171
snprintf(name, 256, "%s", base);
11811172

11821173
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -123,10 +123,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
123123
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
124124
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
125125
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
126-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
127126
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
128127
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
129-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
128+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
130129
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
131130
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
132131
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -661,13 +661,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
661661
case GGML_OP_SOFT_MAX:
662662
case GGML_OP_GROUP_NORM:
663663
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
664-
case GGML_OP_RMS_NORM:
665664
case GGML_OP_L2_NORM:
666665
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
667666
case GGML_OP_ARGMAX:
668667
return has_simdgroup_reduction;
669668
case GGML_OP_NORM:
670-
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
669+
case GGML_OP_RMS_NORM:
670+
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
671671
case GGML_OP_ROPE:
672672
return true;
673673
case GGML_OP_IM2COL:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -428,16 +428,11 @@ typedef struct {
428428
uint64_t nb1;
429429
} ggml_metal_kargs_mul_mv_id;
430430

431+
// NORM
432+
// RMS_NORM
431433
typedef struct {
432434
int32_t ne00;
433-
int32_t ne00_4;
434-
uint64_t nb01;
435-
float eps;
436-
} ggml_metal_kargs_norm;
437-
438-
typedef struct {
439-
int32_t ne00;
440-
int32_t ne00_4;
435+
int32_t ne00_t;
441436
uint64_t nb1;
442437
uint64_t nb2;
443438
uint64_t nb3;
@@ -448,7 +443,7 @@ typedef struct {
448443
uint64_t nbf1[3];
449444
uint64_t nbf2[3];
450445
uint64_t nbf3[3];
451-
} ggml_metal_kargs_rms_norm;
446+
} ggml_metal_kargs_norm;
452447

453448
typedef struct {
454449
int32_t ne00;

0 commit comments

Comments
 (0)