Skip to content

Commit 00681df

Browse files
authored
CUDA: Add fastdiv to k_bin_bcast*, giving 1-3% E2E performance (#15872)
* Add fastdiv and fastmodulo to k_bin_bcast kernel
* Address review comments
* `prod_` instead of `prod` suffix
* Add test case for `k_bin_bcast_unravel` in CUDA backend
1 parent 4f65885 commit 00681df

File tree

2 files changed

+112
-75
lines changed

2 files changed

+112
-75
lines changed

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 109 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
2323
return a / b;
2424
}
2525

26-
27-
28-
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
29-
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
30-
const int ne0, const int ne1, const int ne2, const int ne3,
31-
const int ne10, const int ne11, const int ne12, const int ne13,
32-
/*int s0, */ const int s1, const int s2, const int s3,
33-
/*int s00,*/ const int s01, const int s02, const int s03,
34-
/*int s10,*/ const int s11, const int s12, const int s13,
35-
src1_ptrs... src1s) {
36-
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
37-
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
38-
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
39-
const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
40-
41-
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
26+
template <float (*bin_op)(const float, const float),
27+
typename src0_t,
28+
typename src1_t,
29+
typename dst_t,
30+
typename... src1_ptrs>
31+
static __global__ void k_bin_bcast(const src0_t * src0,
32+
const src1_t * src1,
33+
dst_t * dst,
34+
const int ne0,
35+
const int ne1,
36+
const int ne2,
37+
const uint3 ne3,
38+
const uint3 ne10,
39+
const uint3 ne11,
40+
const uint3 ne12,
41+
const uint3 ne13,
42+
/*int s0, */ const int s1,
43+
const int s2,
44+
const int s3,
45+
/*int s00,*/ const int s01,
46+
const int s02,
47+
const int s03,
48+
/*int s10,*/ const int s11,
49+
const int s12,
50+
const int s13,
51+
src1_ptrs... src1s) {
52+
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
53+
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
54+
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
55+
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
56+
57+
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
4258
return;
4359
}
4460

45-
const int i11 = i1 % ne11;
46-
const int i12 = i2 % ne12;
47-
const int i13 = i3 % ne13;
61+
const uint32_t i11 = fastmodulo(i1, ne11);
62+
const uint32_t i12 = fastmodulo(i2, ne12);
63+
const uint32_t i13 = fastmodulo(i3, ne13);
4864

4965
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
5066
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
5369
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
5470
dst_t * dst_row = dst + i_dst;
5571

56-
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
57-
const int i10 = i0 % ne10;
72+
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
73+
const uint32_t i10 = fastmodulo(i0, ne10);
5874

5975
float result = src0_row ? (float) src0_row[i0] : 0.0f;
6076
if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
6783
}
6884
}
6985

70-
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
71-
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
72-
const int ne0, const int ne1, const int ne2,const int ne3,
73-
const int ne10, const int ne11, const int ne12, const int ne13,
74-
/*int s0, */ const int s1, const int s2, const int s3,
75-
/*int s00,*/ const int s01, const int s02, const int s03,
76-
/*int s10,*/ const int s11, const int s12, const int s13,
77-
src1_ptrs ... src1s) {
86+
template <float (*bin_op)(const float, const float),
87+
typename src0_t,
88+
typename src1_t,
89+
typename dst_t,
90+
typename... src1_ptrs>
91+
static __global__ void k_bin_bcast_unravel(const src0_t * src0,
92+
const src1_t * src1,
93+
dst_t * dst,
94+
const uint3 ne0,
95+
const uint3 ne1,
96+
const uint3 ne2,
97+
const uint32_t ne3,
98+
const uint3 prod_012,
99+
const uint3 prod_01,
100+
const uint3 ne10,
101+
const uint3 ne11,
102+
const uint3 ne12,
103+
const uint3 ne13,
104+
/*int s0, */ const int s1,
105+
const int s2,
106+
const int s3,
107+
/*int s00,*/ const int s01,
108+
const int s02,
109+
const int s03,
110+
/*int s10,*/ const int s11,
111+
const int s12,
112+
const int s13,
113+
src1_ptrs... src1s) {
78114
const int i = blockDim.x*blockIdx.x + threadIdx.x;
79115

80-
const int i3 = i/(ne2*ne1*ne0);
81-
const int i2 = (i/(ne1*ne0)) % ne2;
82-
const int i1 = (i/ne0) % ne1;
83-
const int i0 = i % ne0;
116+
const uint32_t i3 = fastdiv(i, prod_012);
117+
const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
118+
const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
119+
const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
84120

85-
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
121+
if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
86122
return;
87123
}
88124

89-
const int i11 = i1 % ne11;
90-
const int i12 = i2 % ne12;
91-
const int i13 = i3 % ne13;
125+
const int i11 = fastmodulo(i1, ne11);
126+
const int i12 = fastmodulo(i2, ne12);
127+
const int i13 = fastmodulo(i3, ne13);
92128

93129
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
94130
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
97133
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
98134
dst_t * dst_row = dst + i_dst;
99135

100-
const int i10 = i0 % ne10;
136+
const int i10 = fastmodulo(i0, ne10);
101137

102138
float result = src0_row ? (float) src0_row[i0] : 0.0f;
103139
if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
170206
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
171207
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
172208

173-
int64_t ne10 = cne1[0];
174-
int64_t ne11 = cne1[1];
175-
int64_t ne12 = cne1[2];
176-
int64_t ne13 = cne1[3];
177-
178209
size_t nb0 = cnb[0];
179210
size_t nb1 = cnb[1];
180211
size_t nb2 = cnb[2];
@@ -233,48 +264,51 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
233264
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
234265
block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
235266

236-
dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
237-
(ne1 + block_dims.y - 1) / block_dims.y,
267+
dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
238268
(ne2 * ne3 + block_dims.z - 1) / block_dims.z);
239269

270+
const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
271+
const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
272+
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
273+
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
274+
240275
if (block_nums.z > 65535) {
241-
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
276+
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
277+
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
278+
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
279+
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
280+
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
281+
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
282+
242283
if constexpr (sizeof...(I) > 0) {
243-
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
244-
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
245-
ne0, ne1, ne2, ne3,
246-
ne10, ne11, ne12, ne13,
247-
/* s0, */ s1, s2, s3,
248-
/* s00,*/ s01, s02, s03,
249-
/* s10,*/ s11, s12,s13,
250-
(const src1_t *) dst->src[I + 1]->data...);
284+
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
285+
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
286+
ne12, ne13,
287+
/* s0, */ s1, s2, s3,
288+
/* s00,*/ s01, s02, s03,
289+
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
251290
} else {
252291
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
253-
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
254-
ne0, ne1, ne2, ne3,
255-
ne10, ne11, ne12, ne13,
256-
/* s0, */ s1, s2, s3,
257-
/* s00,*/ s01, s02, s03,
258-
/* s10,*/ s11, s12,s13);
292+
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
293+
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
294+
/* s0, */ s1, s2, s3,
295+
/* s00,*/ s01, s02, s03,
296+
/* s10,*/ s11, s12, s13);
259297
}
260298
} else {
299+
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
261300
if constexpr (sizeof...(I) > 0) {
262-
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
263-
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
264-
ne0, ne1, ne2, ne3,
265-
ne10, ne11, ne12, ne13,
266-
/* s0, */ s1, s2, s3,
267-
/* s00,*/ s01, s02, s03,
268-
/* s10,*/ s11, s12,s13,
269-
(const src1_t *) dst->src[I + 1]->data...);
301+
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
302+
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
303+
/* s0, */ s1, s2, s3,
304+
/* s00,*/ s01, s02, s03,
305+
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
270306
} else {
271-
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
272-
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
273-
ne0, ne1, ne2, ne3,
274-
ne10, ne11, ne12, ne13,
275-
/* s0, */ s1, s2, s3,
276-
/* s00,*/ s01, s02, s03,
277-
/* s10,*/ s11, s12,s13);
307+
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
308+
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
309+
/* s0, */ s1, s2, s3,
310+
/* s00,*/ s01, s02, s03,
311+
/* s10,*/ s11, s12, s13);
278312
}
279313
}
280314
}

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6050,6 +6050,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
60506050
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
60516051
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
60526052

6053+
// test case for k_bin_bcast_unravel in CUDA backend
6054+
add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});
6055+
60536056
// stable diffusion
60546057
add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
60556058
add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});

0 commit comments

Comments
 (0)