@@ -1090,36 +1090,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
1090
1090
return res;
1091
1091
}
1092
1092
1093
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
1094
- assert (op->op == GGML_OP_RMS_NORM);
1095
-
1096
- GGML_ASSERT (op->src [0 ]->ne [0 ] % 4 == 0 );
1097
- GGML_ASSERT (ggml_is_contiguous_rows (op->src [0 ]));
1098
-
1099
- char base[256 ];
1100
- char name[256 ];
1101
-
1102
- switch (n_fuse) {
1103
- case 1 : snprintf (base, 256 , " kernel_rms_norm_f32" ); break ;
1104
- case 2 : snprintf (base, 256 , " kernel_rms_norm_mul_f32" ); break ;
1105
- case 3 : snprintf (base, 256 , " kernel_rms_norm_mul_add_f32" ); break ;
1106
- default : GGML_ABORT (" fatal error" );
1107
- }
1108
-
1109
- snprintf (name, 256 , " %s" , base);
1110
-
1111
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
1112
- if (res) {
1113
- return res;
1114
- }
1115
-
1116
- res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
1117
-
1118
- ggml_metal_pipeline_set_smem (res, 32 *sizeof (float ));
1119
-
1120
- return res;
1121
- }
1122
-
1123
1093
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const ggml_tensor * op) {
1124
1094
assert (op->op == GGML_OP_L2_NORM);
1125
1095
@@ -1167,16 +1137,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
1167
1137
return res;
1168
1138
}
1169
1139
1170
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const ggml_tensor * op) {
1171
- assert (op->op == GGML_OP_NORM);
1140
+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse ) {
1141
+ assert (op->op == GGML_OP_NORM || op-> op == GGML_OP_RMS_NORM );
1172
1142
1173
- GGML_ASSERT (op->src [0 ]->ne [0 ] % 4 == 0 );
1174
- GGML_ASSERT (ggml_is_contiguous_1 (op->src [0 ]));
1143
+ GGML_ASSERT (ggml_is_contiguous_rows (op->src [0 ]));
1175
1144
1176
1145
char base[256 ];
1177
1146
char name[256 ];
1178
1147
1179
- snprintf (base, 256 , " kernel_norm_f32" );
1148
+ const char * suffix = " " ;
1149
+ if (op->ne [0 ] % 4 == 0 ) {
1150
+ suffix = " _4" ;
1151
+ }
1152
+
1153
+ switch (op->op ) {
1154
+ case GGML_OP_NORM:
1155
+ switch (n_fuse) {
1156
+ case 1 : snprintf (base, 256 , " kernel_norm_f32%s" , suffix); break ;
1157
+ case 2 : snprintf (base, 256 , " kernel_norm_mul_f32%s" , suffix); break ;
1158
+ case 3 : snprintf (base, 256 , " kernel_norm_mul_add_f32%s" , suffix); break ;
1159
+ default : GGML_ABORT (" fatal error" );
1160
+ } break ;
1161
+ case GGML_OP_RMS_NORM:
1162
+ switch (n_fuse) {
1163
+ case 1 : snprintf (base, 256 , " kernel_rms_norm_f32%s" , suffix); break ;
1164
+ case 2 : snprintf (base, 256 , " kernel_rms_norm_mul_f32%s" , suffix); break ;
1165
+ case 3 : snprintf (base, 256 , " kernel_rms_norm_mul_add_f32%s" , suffix); break ;
1166
+ default : GGML_ABORT (" fatal error" );
1167
+ } break ;
1168
+ default : GGML_ABORT (" fatal error" );
1169
+ }
1170
+
1180
1171
snprintf (name, 256 , " %s" , base);
1181
1172
1182
1173
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
0 commit comments