@@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
    return a / b;
}

-
-
- template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
- static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-                                    const int ne0, const int ne1, const int ne2, const int ne3,
-                                    const int ne10, const int ne11, const int ne12, const int ne13,
-                                    /* int s0, */ const int s1, const int s2, const int s3,
-                                    /* int s00,*/ const int s01, const int s02, const int s03,
-                                    /* int s10,*/ const int s11, const int s12, const int s13,
-                                    src1_ptrs... src1s) {
-     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
-     const int i1  = (blockDim.y*blockIdx.y + threadIdx.y);
-     const int i2  = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
-     const int i3  = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
-
-     if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ template <float (*bin_op)(const float, const float),
+           typename src0_t,
+           typename src1_t,
+           typename dst_t,
+           typename... src1_ptrs>
+ static __global__ void k_bin_bcast(const src0_t * src0,
+                                    const src1_t * src1,
+                                    dst_t * dst,
+                                    const int ne0,
+                                    const int ne1,
+                                    const int ne2,
+                                    const uint3 ne3,
+                                    const uint3 ne10,
+                                    const uint3 ne11,
+                                    const uint3 ne12,
+                                    const uint3 ne13,
+                                    /* int s0, */ const int s1,
+                                    const int s2,
+                                    const int s3,
+                                    /* int s00,*/ const int s01,
+                                    const int s02,
+                                    const int s03,
+                                    /* int s10,*/ const int s11,
+                                    const int s12,
+                                    const int s13,
+                                    src1_ptrs... src1s) {
+     const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
+     const uint32_t i1  = (blockDim.y * blockIdx.y + threadIdx.y);
+     const uint32_t i2  = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
+     const uint32_t i3  = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
+
+     if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
        return;
    }

-     const int i11 = i1 % ne11;
-     const int i12 = i2 % ne12;
-     const int i13 = i3 % ne13;
+     const uint32_t i11 = fastmodulo(i1, ne11);
+     const uint32_t i12 = fastmodulo(i2, ne12);
+     const uint32_t i13 = fastmodulo(i3, ne13);

    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

-     for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
-         const int i10 = i0 % ne10;
+     for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
+         const uint32_t i10 = fastmodulo(i0, ne10);

        float result = src0_row ? (float) src0_row[i0] : 0.0f;
        if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
    }
}

- template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
- static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-                                            const int ne0, const int ne1, const int ne2, const int ne3,
-                                            const int ne10, const int ne11, const int ne12, const int ne13,
-                                            /* int s0, */ const int s1, const int s2, const int s3,
-                                            /* int s00,*/ const int s01, const int s02, const int s03,
-                                            /* int s10,*/ const int s11, const int s12, const int s13,
-                                            src1_ptrs... src1s) {
+ template <float (*bin_op)(const float, const float),
+           typename src0_t,
+           typename src1_t,
+           typename dst_t,
+           typename... src1_ptrs>
+ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
+                                            const src1_t * src1,
+                                            dst_t * dst,
+                                            const uint3 ne0,
+                                            const uint3 ne1,
+                                            const uint3 ne2,
+                                            const uint32_t ne3,
+                                            const uint3 prod_012,
+                                            const uint3 prod_01,
+                                            const uint3 ne10,
+                                            const uint3 ne11,
+                                            const uint3 ne12,
+                                            const uint3 ne13,
+                                            /* int s0, */ const int s1,
+                                            const int s2,
+                                            const int s3,
+                                            /* int s00,*/ const int s01,
+                                            const int s02,
+                                            const int s03,
+                                            /* int s10,*/ const int s11,
+                                            const int s12,
+                                            const int s13,
+                                            src1_ptrs... src1s) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

-     const int i3 = i/(ne2*ne1*ne0);
-     const int i2 = (i/(ne1*ne0)) % ne2;
-     const int i1 = (i/ne0) % ne1;
-     const int i0 = i % ne0;
+     const uint32_t i3 = fastdiv(i, prod_012);
+     const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
+     const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
+     const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;

-     if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+     if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
        return;
    }

-     const int i11 = i1 % ne11;
-     const int i12 = i2 % ne12;
-     const int i13 = i3 % ne13;
+     const int i11 = fastmodulo(i1, ne11);
+     const int i12 = fastmodulo(i2, ne12);
+     const int i13 = fastmodulo(i3, ne13);

    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

-     const int i10 = i0 % ne10;
+     const int i10 = fastmodulo(i0, ne10);

    float result = src0_row ? (float) src0_row[i0] : 0.0f;
    if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
    // int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
    // int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);

-     int64_t ne10 = cne1[0];
-     int64_t ne11 = cne1[1];
-     int64_t ne12 = cne1[2];
-     int64_t ne13 = cne1[3];
-
    size_t nb0 = cnb[0];
    size_t nb1 = cnb[1];
    size_t nb2 = cnb[2];
@@ -233,48 +264,51 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
        block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
        block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);

-         dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
-                         (ne1 + block_dims.y - 1) / block_dims.y,
+         dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
                        (ne2 * ne3 + block_dims.z - 1) / block_dims.z);

+         const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
+         const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
+         const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
+         const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
+
        if (block_nums.z > 65535) {
-             int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+             int         block_num   = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+             const uint3 prod_012    = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
+             const uint3 prod_01     = init_fastdiv_values((uint32_t) (ne0 * ne1));
+             const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
+             const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
+             const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
+
            if constexpr (sizeof...(I) > 0) {
-                 k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
-                     <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                                                            ne0, ne1, ne2, ne3,
-                                                            ne10, ne11, ne12, ne13,
-                                                            /* s0, */ s1, s2, s3,
-                                                            /* s00,*/ s01, s02, s03,
-                                                            /* s10,*/ s11, s12, s13,
-                                                            (const src1_t *) dst->src[I + 1]->data...);
+                 k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
+                     src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
+                     ne12, ne13,
+                     /* s0, */ s1, s2, s3,
+                     /* s00,*/ s01, s02, s03,
+                     /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
-                     <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                                                            ne0, ne1, ne2, ne3,
-                                                            ne10, ne11, ne12, ne13,
-                                                            /* s0, */ s1, s2, s3,
-                                                            /* s00,*/ s01, s02, s03,
-                                                            /* s10,*/ s11, s12, s13);
+                     <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
+                                                            ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
+                                                            /* s0, */ s1, s2, s3,
+                                                            /* s00,*/ s01, s02, s03,
+                                                            /* s10,*/ s11, s12, s13);
            }
        } else {
+             const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
            if constexpr (sizeof...(I) > 0) {
-                 k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
-                     <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                                                             ne0, ne1, ne2, ne3,
-                                                             ne10, ne11, ne12, ne13,
-                                                             /* s0, */ s1, s2, s3,
-                                                             /* s00,*/ s01, s02, s03,
-                                                             /* s10,*/ s11, s12, s13,
-                                                             (const src1_t *) dst->src[I + 1]->data...);
+                 k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
+                     src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+                     /* s0, */ s1, s2, s3,
+                     /* s00,*/ s01, s02, s03,
+                     /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
-                 k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
-                     <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                                                             ne0, ne1, ne2, ne3,
-                                                             ne10, ne11, ne12, ne13,
-                                                             /* s0, */ s1, s2, s3,
-                                                             /* s00,*/ s01, s02, s03,
-                                                             /* s10,*/ s11, s12, s13);
+                 k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
+                     src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+                     /* s0, */ s1, s2, s3,
+                     /* s00,*/ s01, s02, s03,
+                     /* s10,*/ s11, s12, s13);
            }
        }
    }
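
Note on the fastdiv/fastmodulo helpers referenced above: they replace the 32-bit integer divisions and modulos in the indexing hot path with a precomputed multiply-and-shift, with the divisor packed into a uint3 by init_fastdiv_values on the host (the kernels themselves only rely on .z holding the plain divisor, e.g. ne3.z). The sketch below is a minimal illustration of one common Granlund-Montgomery-style 32-bit scheme, under the assumption that .x holds the magic multiplier and .y the shift; ggml's actual helpers live in ggml-cuda/common.cuh and may differ in detail, and the *_sketch names here are hypothetical.

#include <cstdint>
#include <cuda_runtime.h>

// Host side: pack {magic multiplier, shift, divisor} for a fixed divisor d.
// Assumes 0 < d < 2^31, which always holds for the tensor extents used here.
static uint3 init_fastdiv_values_sketch(uint32_t d) {
    uint32_t L = 0;                                   // L = ceil(log2(d))
    while (L < 32 && (uint32_t{1} << L) < d) {
        ++L;
    }
    // mp = floor(2^32 * (2^L - d) / d) + 1; for d == 1 or d a power of two this degenerates to mp = 1.
    const uint32_t mp = (uint32_t) (((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    return make_uint3(mp, L, d);
}

// Device side: n / d == (umulhi(n, mp) + n) >> L for the index ranges used by these kernels.
static __device__ __forceinline__ uint32_t fastdiv_sketch(uint32_t n, uint3 fd) {
    return (__umulhi(n, fd.x) + n) >> fd.y;
}

// n % d == n - (n / d) * d, so the modulo costs one extra multiply-subtract.
static __device__ __forceinline__ uint32_t fastmodulo_sketch(uint32_t n, uint3 fd) {
    return n - fastdiv_sketch(n, fd) * fd.z;
}

The same trick explains the index decomposition in k_bin_bcast_unravel: i3 = i / (ne0*ne1*ne2) is computed with fastdiv, and each remaining coordinate is recovered with a multiply-subtract (e.g. i0 = i - i3*prod_012.z - i2*prod_01.z - i1*ne0.z) instead of a chain of % operations.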