Skip to content

Commit 803dac2

Browse files
authored
vulkan: use vec dot for matrix matrix multiplications (#16056)
* vulkan: Change the mul_mm shared memory and register caching system to use vec2 instead of scalars, to enable using dot2 instructions
* use fma instead of dot to fix Nvidia and Apple performance issues
1 parent 459c0c2 commit 803dac2

File tree

4 files changed

+187
-199
lines changed

4 files changed

+187
-199
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -31,10 +31,10 @@
3131
#include "types.comp"
3232

3333
#ifndef LOAD_VEC_A
34-
#define LOAD_VEC_A 1
34+
#define LOAD_VEC_A 2
3535
#endif
3636
#ifndef LOAD_VEC_B
37-
#define LOAD_VEC_B 1
37+
#define LOAD_VEC_B 2
3838
#endif
3939

4040
#if !defined(TO_FLOAT_TYPE)
@@ -98,13 +98,13 @@ layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
9898
layout (constant_id = 10) const uint WARP = 32;
9999

100100
#ifdef COOPMAT
101-
#define SHMEM_STRIDE (BK + 8)
101+
#define SHMEM_STRIDE (BK / 2 + 4)
102102
#else
103-
#define SHMEM_STRIDE (BK + 1)
103+
#define SHMEM_STRIDE (BK / 2 + 1)
104104
#endif
105105

106-
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
107-
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
106+
shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
107+
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
108108

109109
#define NUM_WARPS (BLOCK_SIZE / WARP)
110110

@@ -302,8 +302,8 @@ void main() {
302302
}
303303
#else
304304
ACC_TYPE sums[WMITER * TM * WNITER * TN];
305-
FLOAT_TYPE cache_a[WMITER * TM];
306-
FLOAT_TYPE cache_b[TN];
305+
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
306+
FLOAT_TYPE_VEC2 cache_b[TN];
307307

308308
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
309309
sums[i] = ACC_TYPE(0.0f);
@@ -312,13 +312,13 @@ void main() {
312312

313313
for (uint block = start_k; block < end_k; block += BK) {
314314
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
315-
load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block + loadr_a, end_k);
315+
load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k);
316316
}
317317
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
318318
#if !defined(MUL_MAT_ID)
319-
load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block + loadr_b, end_k);
319+
load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);
320320
#else
321-
load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block + loadr_b, end_k);
321+
load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k);
322322
#endif
323323
}
324324

@@ -331,17 +331,17 @@ void main() {
331331
[[unroll]] for (uint i = 0; i < BK; i += TK) {
332332
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
333333
// Load from shared into cache
334-
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
334+
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
335335

336336
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
337-
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
337+
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
338338

339339
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
340340
}
341341
}
342342
}
343343
#else
344-
[[unroll]] for (uint i = 0; i < BK; i++) {
344+
[[unroll]] for (uint i = 0; i < BK / 2; i++) {
345345
// Load from shared into cache
346346
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
347347
[[unroll]] for (uint j = 0; j < TM; j++) {
@@ -357,7 +357,7 @@ void main() {
357357
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
358358
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
359359
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
360-
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
360+
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
361361
}
362362
}
363363
}

0 commit comments

Comments
 (0)