31
31
#include "types.comp"
32
32
33
33
#ifndef LOAD_VEC_A
34
- #define LOAD_VEC_A 1
34
+ #define LOAD_VEC_A 2
35
35
#endif
36
36
#ifndef LOAD_VEC_B
37
- #define LOAD_VEC_B 1
37
+ #define LOAD_VEC_B 2
38
38
#endif
39
39
40
40
#if !defined(TO_FLOAT_TYPE)
@@ -98,13 +98,13 @@ layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
98
98
layout (constant_id = 10) const uint WARP = 32;
99
99
100
100
#ifdef COOPMAT
101
- #define SHMEM_STRIDE (BK + 8 )
101
+ #define SHMEM_STRIDE (BK / 2 + 4 )
102
102
#else
103
- #define SHMEM_STRIDE (BK + 1)
103
+ #define SHMEM_STRIDE (BK / 2 + 1)
104
104
#endif
105
105
106
- shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
107
- shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
106
+ shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
107
+ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
108
108
109
109
#define NUM_WARPS (BLOCK_SIZE / WARP)
110
110
@@ -302,8 +302,8 @@ void main() {
302
302
}
303
303
#else
304
304
ACC_TYPE sums[WMITER * TM * WNITER * TN];
305
- FLOAT_TYPE cache_a[WMITER * TM];
306
- FLOAT_TYPE cache_b[TN];
305
+ FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
306
+ FLOAT_TYPE_VEC2 cache_b[TN];
307
307
308
308
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
309
309
sums[i] = ACC_TYPE(0.0f);
@@ -312,13 +312,13 @@ void main() {
312
312
313
313
for (uint block = start_k; block < end_k; block += BK) {
314
314
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
315
- load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block + loadr_a , end_k);
315
+ load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k);
316
316
}
317
317
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
318
318
#if !defined(MUL_MAT_ID)
319
- load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block + loadr_b , end_k);
319
+ load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);
320
320
#else
321
- load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block + loadr_b , end_k);
321
+ load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k);
322
322
#endif
323
323
}
324
324
@@ -331,17 +331,17 @@ void main() {
331
331
[[unroll]] for (uint i = 0; i < BK; i += TK) {
332
332
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
333
333
// Load from shared into cache
334
- coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
334
+ coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2 , SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
335
335
336
336
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
337
- coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
337
+ coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2 , SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
338
338
339
339
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
340
340
}
341
341
}
342
342
}
343
343
#else
344
- [[unroll]] for (uint i = 0; i < BK; i++) {
344
+ [[unroll]] for (uint i = 0; i < BK / 2 ; i++) {
345
345
// Load from shared into cache
346
346
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
347
347
[[unroll]] for (uint j = 0; j < TM; j++) {
@@ -357,7 +357,7 @@ void main() {
357
357
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
358
358
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
359
359
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
360
- sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
360
+ sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x ), ACC_TYPE(cache_b[cc].x ), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]) );
361
361
}
362
362
}
363
363
}
0 commit comments