Commit 3ecb2f6

Authored by CISC, ggerganov, and jeffbolznv
ggml : implement set_rows with i32 index (#16159)
* implement set_rows with i32 index
* template fix
* test quantized path
  warnings--
* Apply suggestions from code review
  Co-authored-by: Georgi Gerganov <[email protected]>
* forgotten name change
* deduplicate cuda/sycl and test-fix
* indent++
* vulkan: support set_rows with i32 index type (#16162)
* disable i32 index for webgpu for now

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: Jeff Bolz <[email protected]>
1 parent 432cf43 commit 3ecb2f6

17 files changed (+300 lines, -134 lines)

ggml/src/ggml-cpu/ops.cpp

Lines changed: 10 additions & 2 deletions

@@ -4739,6 +4739,7 @@ void ggml_compute_forward_get_rows(
     //}
 }
 
+template<typename idx_t>
 static void ggml_compute_forward_set_rows_f32(
         const ggml_compute_params * params,
               ggml_tensor * dst) {
@@ -4777,7 +4778,7 @@ static void ggml_compute_forward_set_rows_f32(
             const int64_t i11 = i02%ne11;
             const int64_t i10 = i;
 
-            const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+            const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
             GGML_ASSERT(i1 >= 0 && i1 < ne1);
 
@@ -4794,11 +4795,18 @@ void ggml_compute_forward_set_rows(
               ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_set_rows_f32(params, dst);
+                if (src1->type == GGML_TYPE_I64) {
+                    ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
+                } else if (src1->type == GGML_TYPE_I32) {
+                    ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
+                } else {
+                    GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
+                }
             } break;
         default:
             {
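Note (not part of the diff): a minimal sketch of exercising the new I32 index path through the public API, assuming the ggml_set_rows(ctx, dst, src, idx) declaration from ggml.h; the shapes and the helper name are illustrative.

    #include "ggml.h"

    // Overwrite 4 rows of a destination with 16 rows of width 64, using an
    // I32 index tensor to pick the target rows (previously I64 was required).
    static struct ggml_tensor * build_set_rows_i32(struct ggml_context * ctx) {
        struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 16);
        struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64,  4);
        struct ggml_tensor * idx = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
        return ggml_set_rows(ctx, dst, src, idx);
    }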

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion

@@ -3427,7 +3427,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
                     op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
                    op->src[0]->type == GGML_TYPE_F32 &&
-                   op->src[1]->type == GGML_TYPE_I64;
+                   (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
            } break;
        case GGML_OP_CPY:
            {

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 40 additions & 32 deletions

@@ -4,9 +4,9 @@
 
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 
 // Generic quantized set_rows kernel template
-template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
+template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
 static __global__ void k_set_rows_quant(
-        const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
+        const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
         const int64_t s01, const int64_t s02, const int64_t s03,
@@ -45,9 +45,9 @@ static __global__ void k_set_rows_quant(
 }
 
 // Template dispatch function for quantized set_rows
-template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
+template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
 static void set_rows_cuda_quant(
-    const float * src0_d, const int64_t * src1_d, block_type * dst_d,
+    const float * src0_d, const idx_t * src1_d, block_type * dst_d,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
     const size_t nb01, const size_t nb02, const size_t nb03,
@@ -64,15 +64,15 @@ static void set_rows_cuda_quant(
     const int64_t s01 = nb01/sizeof(float);
     const int64_t s02 = nb02/sizeof(float);
     const int64_t s03 = nb03/sizeof(float);
-    const int64_t s10 = nb10/sizeof(int64_t);
-    const int64_t s11 = nb11/sizeof(int64_t);
-    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s10 = nb10/sizeof(idx_t);
+    const int64_t s11 = nb11/sizeof(idx_t);
+    const int64_t s12 = nb12/sizeof(idx_t);
     const int64_t s1 = nb1;
     const int64_t s2 = nb2;
     const int64_t s3 = nb3;
 
     if (ne_total > 0) {
-        k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
+        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
             src0_d, src1_d, dst_d,
             ne00, ne01, ne02, ne03,
             ne10, ne11, ne12, ne13,
@@ -82,9 +82,9 @@ static void set_rows_cuda_quant(
     }
 }
 
-template<typename src_t, typename dst_t>
+template<typename src_t, typename idx_t, typename dst_t>
 static __global__ void k_set_rows(
-        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
         const int64_t s01, const int64_t s02, const int64_t s03,
@@ -118,9 +118,9 @@ static __global__ void k_set_rows(
     GGML_UNUSED(ne13);
 }
 
-template<typename src_t, typename dst_t>
+template<typename src_t, typename idx_t, typename dst_t>
 static void set_rows_cuda(
-    const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
+    const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
     const size_t nb01, const size_t nb02, const size_t nb03,
@@ -137,9 +137,9 @@ static void set_rows_cuda(
     const int64_t s01 = nb01/sizeof(src_t);
     const int64_t s02 = nb02/sizeof(src_t);
     const int64_t s03 = nb03/sizeof(src_t);
-    const int64_t s10 = nb10/sizeof(int64_t);
-    const int64_t s11 = nb11/sizeof(int64_t);
-    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s10 = nb10/sizeof(idx_t);
+    const int64_t s11 = nb11/sizeof(idx_t);
+    const int64_t s12 = nb12/sizeof(idx_t);
     const int64_t s1 = nb1/sizeof(dst_t);
     const int64_t s2 = nb2/sizeof(dst_t);
     const int64_t s3 = nb3/sizeof(dst_t);
@@ -155,23 +155,16 @@ static void set_rows_cuda(
     }
 }
 
-
-void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_I64);
+template<typename src_t, typename idx_t>
+static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const src_t * src0_d = (const src_t *)src0->data;
+    const idx_t * src1_d = (const idx_t *)src1->data;
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const float * src0_d = (const float *)src0->data;
-    const int64_t * src1_d = (const int64_t *)src1->data;
-
     cudaStream_t stream = ctx.stream();
 
 
-
     if (dst->type == GGML_TYPE_F32) {
         set_rows_cuda(
             src0_d, src1_d, (float*)dst->data,
@@ -203,7 +196,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             stream
         );
     } else if (dst->type == GGML_TYPE_Q4_0) {
-        set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
+        set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
            src0_d, src1_d, (block_q4_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -213,7 +206,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_1) {
-        set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
+        set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
            src0_d, src1_d, (block_q4_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -223,7 +216,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_0) {
-        set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
+        set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
            src0_d, src1_d, (block_q5_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -233,7 +226,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_1) {
-        set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
+        set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
            src0_d, src1_d, (block_q5_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -243,7 +236,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q8_0) {
-        set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
+        set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
            src0_d, src1_d, (block_q8_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -253,7 +246,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_IQ4_NL) {
-        set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
+        set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
            src0_d, src1_d, (block_iq4_nl*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -266,3 +259,18 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
     }
 }
+
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
+
+    if (src1->type == GGML_TYPE_I64) {
+        set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
+    } else {
+        set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
+    }
+}
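Note (not part of the diff): the refactor above is the usual pattern of promoting a runtime tensor type to a compile-time template parameter, so a single kernel body serves both index widths and the type branch happens once at the entry point. A standalone CPU sketch of the same idea, with names of our own choosing:

    #include <cstdint>

    // Mirrors the kernel logic: read an index of type idx_t, widen it to
    // 64-bit, then scatter the source row into the destination.
    template <typename idx_t>
    static void copy_rows(const float * src, const void * idx_raw, float * dst,
                          int64_t n_rows, int64_t row_size) {
        const idx_t * idx = (const idx_t *) idx_raw;
        for (int64_t r = 0; r < n_rows; ++r) {
            const int64_t i1 = (int64_t) idx[r];
            for (int64_t c = 0; c < row_size; ++c) {
                dst[i1*row_size + c] = src[r*row_size + c];
            }
        }
    }

    // One runtime branch at the boundary, as in ggml_cuda_op_set_rows above.
    static void set_rows(const float * src, const void * idx, bool idx_is_i64,
                         float * dst, int64_t n_rows, int64_t row_size) {
        if (idx_is_i64) {
            copy_rows<int64_t>(src, idx, dst, n_rows, row_size);
        } else {
            copy_rows<int32_t>(src, idx, dst, n_rows, row_size);
        }
    }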

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 2 additions & 2 deletions

@@ -142,11 +142,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_librar
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tdst) {
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_set_rows_%s", ggml_type_name(tdst));
+    snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
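Note (not part of the diff): with the index type folded into the name, the library now resolves pipelines such as kernel_set_rows_f32_i64 and kernel_set_rows_f32_i32. A small harness of our own that enumerates the naming scheme:

    #include <cstdio>

    int main() {
        // destination types abbreviated; the commit covers f32/f16/bf16 and
        // the q32-style quant formats, each in an _i64 and an _i32 variant
        const char * tdst[] = { "f32", "f16", "q8_0" };
        const char * tidx[] = { "i64", "i32" };
        char name[256];
        for (const char * d : tdst) {
            for (const char * i : tidx) {
                snprintf(name, sizeof(name), "kernel_set_rows_%s_%s", d, i);
                printf("%s\n", name); // e.g. kernel_set_rows_q8_0_i32
            }
        }
        return 0;
    }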

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 1 deletion

@@ -105,7 +105,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base            (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy             (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d         (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows        (ggml_metal_library_t lib, enum ggml_type tsrc);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows        (ggml_metal_library_t lib, enum ggml_type tdst);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows        (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat          (ggml_metal_library_t lib, enum ggml_type tsrc);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu             (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 1 addition & 1 deletion

@@ -892,7 +892,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
     GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->type);
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
 
     const int32_t nk0 = ne0/ggml_blck_size(op->type);
 
ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 25 additions & 16 deletions

@@ -7743,7 +7743,7 @@ kernel void kernel_get_rows_i32(
     }
 }
 
-template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
+template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
 kernel void kernel_set_rows_q32(
         constant ggml_metal_kargs_set_rows & args,
         device const void * src0,
@@ -7764,7 +7764,7 @@ kernel void kernel_set_rows_q32(
     }
 
     const int32_t i10 = i01;
-    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+    const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
 
     device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
     const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
@@ -7774,7 +7774,7 @@ kernel void kernel_set_rows_q32(
     }
 }
 
-template<typename T>
+template<typename T, typename TI>
 kernel void kernel_set_rows_f(
         constant ggml_metal_kargs_set_rows & args,
         device const void * src0,
@@ -7795,7 +7795,7 @@ kernel void kernel_set_rows_f(
     }
 
     const int32_t i10 = i01;
-    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+    const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
 
     device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
     const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
@@ -8218,22 +8218,31 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
 // set rows
 //
 
-typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
+typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
 
-template [[host_name("kernel_set_rows_f32")]]  kernel set_rows_f_t kernel_set_rows_f<float>;
-template [[host_name("kernel_set_rows_f16")]]  kernel set_rows_f_t kernel_set_rows_f<half>;
+template [[host_name("kernel_set_rows_f32_i64")]]  kernel set_rows_f_t kernel_set_rows_f<float,  int64_t>;
+template [[host_name("kernel_set_rows_f32_i32")]]  kernel set_rows_f_t kernel_set_rows_f<float,  int32_t>;
+template [[host_name("kernel_set_rows_f16_i64")]]  kernel set_rows_f_t kernel_set_rows_f<half,   int64_t>;
+template [[host_name("kernel_set_rows_f16_i32")]]  kernel set_rows_f_t kernel_set_rows_f<half,   int32_t>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
+template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
+template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
 #endif
 
-typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
-
-template [[host_name("kernel_set_rows_q8_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0,   quantize_q8_0>;
-template [[host_name("kernel_set_rows_q4_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0,   quantize_q4_0>;
-template [[host_name("kernel_set_rows_q4_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1,   quantize_q4_1>;
-template [[host_name("kernel_set_rows_q5_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0,   quantize_q5_0>;
-template [[host_name("kernel_set_rows_q5_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1,   quantize_q5_1>;
-template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
+typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
+
+template [[host_name("kernel_set_rows_q8_0_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0,   quantize_q8_0>;
+template [[host_name("kernel_set_rows_q8_0_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0,   quantize_q8_0>;
+template [[host_name("kernel_set_rows_q4_0_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0,   quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_0_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0,   quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_1_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1,   quantize_q4_1>;
+template [[host_name("kernel_set_rows_q4_1_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1,   quantize_q4_1>;
+template [[host_name("kernel_set_rows_q5_0_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0,   quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_0_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0,   quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_1_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1,   quantize_q5_1>;
+template [[host_name("kernel_set_rows_q5_1_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1,   quantize_q5_1>;
+template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
+template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
 
 //
 // matrix-matrix multiplication
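Note (not part of the diff): a single decltype can describe every host_name variant above only because the template parameters appear in the kernel body rather than in its signature; the buffer arguments are type-erased pointers, so all instantiations share one function type. A plain C++ analogue of the trick, with illustrative names:

    #include <cstdint>

    // T and TI affect only the body, so every instantiation has the same
    // function type: void (const void *, const void *, void *).
    template <typename T, typename TI>
    void set_rows_f(const void * src, const void * idx, void * dst) {
        ((T *) dst)[(int64_t) ((const TI *) idx)[0]] = ((const T *) src)[0];
    }

    // decltype of one arbitrary instantiation names them all ...
    typedef decltype(set_rows_f<float, int64_t>) set_rows_f_t;

    // ... which allows, e.g., a dispatch table keyed by index width.
    set_rows_f_t * const set_rows_table[2] = {
        set_rows_f<float, int32_t>,
        set_rows_f<float, int64_t>,
    };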
