typedef void (*set_rows_kernel_t)(const char * src, char * dst);

// Generic quantized set_rows kernel template
-template <typename block_type, int qk, void (*quantize_func)(const float *, block_type*)>
+template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type*)>
static __global__ void k_set_rows_quant(
-        const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
+        const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
@@ -45,9 +45,9 @@ static __global__ void k_set_rows_quant(
}

// Template dispatch function for quantized set_rows
-template <typename block_type, int qk, void (*quantize_func)(const float *, block_type*)>
+template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type*)>
static void set_rows_cuda_quant(
-        const float * src0_d, const int64_t * src1_d, block_type * dst_d,
+        const float * src0_d, const idx_t * src1_d, block_type * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
@@ -64,15 +64,15 @@ static void set_rows_cuda_quant(
    const int64_t s01 = nb01/sizeof(float);
    const int64_t s02 = nb02/sizeof(float);
    const int64_t s03 = nb03/sizeof(float);
-    const int64_t s10 = nb10/sizeof(int64_t);
-    const int64_t s11 = nb11/sizeof(int64_t);
-    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s10 = nb10/sizeof(idx_t);
+    const int64_t s11 = nb11/sizeof(idx_t);
+    const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1 = nb1;
    const int64_t s2 = nb2;
    const int64_t s3 = nb3;

    if (ne_total > 0) {
-        k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
+        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -82,9 +82,9 @@ static void set_rows_cuda_quant(
    }
}

-template <typename src_t, typename dst_t>
+template <typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(
-        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
@@ -118,9 +118,9 @@ static __global__ void k_set_rows(
    GGML_UNUSED(ne13);
}

-template <typename src_t, typename dst_t>
+template <typename src_t, typename idx_t, typename dst_t>
static void set_rows_cuda(
-        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
+        const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
@@ -137,9 +137,9 @@ static void set_rows_cuda(
    const int64_t s01 = nb01/sizeof(src_t);
    const int64_t s02 = nb02/sizeof(src_t);
    const int64_t s03 = nb03/sizeof(src_t);
-    const int64_t s10 = nb10/sizeof(int64_t);
-    const int64_t s11 = nb11/sizeof(int64_t);
-    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s10 = nb10/sizeof(idx_t);
+    const int64_t s11 = nb11/sizeof(idx_t);
+    const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1 = nb1/sizeof(dst_t);
    const int64_t s2 = nb2/sizeof(dst_t);
    const int64_t s3 = nb3/sizeof(dst_t);
@@ -155,23 +155,16 @@ static void set_rows_cuda(
    }
}

-
-void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_I64);
+template <typename src_t, typename idx_t>
+static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const src_t * src0_d = (const src_t *)src0->data;
+    const idx_t * src1_d = (const idx_t *)src1->data;

    GGML_TENSOR_BINARY_OP_LOCALS

-    const float * src0_d = (const float *)src0->data;
-    const int64_t * src1_d = (const int64_t *)src1->data;
-
    cudaStream_t stream = ctx.stream();


-
    if (dst->type == GGML_TYPE_F32) {
        set_rows_cuda(
            src0_d, src1_d, (float *)dst->data,
@@ -203,7 +196,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_0) {
-        set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
+        set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
            src0_d, src1_d, (block_q4_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -213,7 +206,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_1) {
-        set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
+        set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
            src0_d, src1_d, (block_q4_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -223,7 +216,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_0) {
-        set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
+        set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
            src0_d, src1_d, (block_q5_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -233,7 +226,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_1) {
-        set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
+        set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
            src0_d, src1_d, (block_q5_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -243,7 +236,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q8_0) {
-        set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
+        set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
            src0_d, src1_d, (block_q8_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -253,7 +246,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_IQ4_NL) {
-        set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
+        set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
            src0_d, src1_d, (block_iq4_nl*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -266,3 +259,18 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
        GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
    }
}
+
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
+
+    if (src1->type == GGML_TYPE_I64) {
+        set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
+    } else {
+        set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
+    }
+}
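The shape of the change: every kernel and host-side launcher now takes the index type of src1 as a template parameter (idx_t), and the public entry point ggml_cuda_op_set_rows picks the instantiation at runtime from the tensor type (GGML_TYPE_I64 or GGML_TYPE_I32). Below is a minimal standalone CUDA sketch of that same pattern, assuming nothing from ggml; the names IndexType, k_copy_rows, copy_rows_cuda, and copy_rows are invented for this illustration and do not exist in the patch.

#include <cstdint>
#include <cuda_runtime.h>

// Runtime tag for the index tensor's element type (stand-in for GGML_TYPE_I32/I64).
enum class IndexType { I32, I64 };

// Kernel templated on the index type: one block copies one source row to dst[rows[row]].
template <typename idx_t>
static __global__ void k_copy_rows(const float * src, const idx_t * rows, float * dst, int ncols) {
    const int     row     = blockIdx.x;
    const int64_t dst_row = (int64_t) rows[row];
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        dst[dst_row * ncols + col] = src[(int64_t) row * ncols + col];
    }
}

// Host-side launcher for a fixed index type.
template <typename idx_t>
static void copy_rows_cuda(const float * src, const void * rows, float * dst,
                           int ncols, int nrows_src, cudaStream_t stream) {
    if (nrows_src > 0) {
        k_copy_rows<idx_t><<<nrows_src, 256, 0, stream>>>(src, (const idx_t *) rows, dst, ncols);
    }
}

// Runtime dispatch on the index type, mirroring the I64/I32 branch in ggml_cuda_op_set_rows.
static void copy_rows(const float * src, const void * rows, IndexType rows_type,
                      float * dst, int ncols, int nrows_src, cudaStream_t stream) {
    if (rows_type == IndexType::I64) {
        copy_rows_cuda<int64_t>(src, rows, dst, ncols, nrows_src, stream);
    } else {
        copy_rows_cuda<int32_t>(src, rows, dst, ncols, nrows_src, stream);
    }
}

The same idea scales to the quantized destination types in the patch: the index type is threaded through as the first template argument of set_rows_cuda_quant, so each (index type, block type) pair gets its own instantiation while the runtime branch stays confined to the entry point.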