Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4739,6 +4739,7 @@ void ggml_compute_forward_get_rows(
//}
}

template<typename idx_t>
static void ggml_compute_forward_set_rows_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
Expand Down Expand Up @@ -4777,7 +4778,7 @@ static void ggml_compute_forward_set_rows_f32(
const int64_t i11 = i02%ne11;
const int64_t i10 = i;

const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

GGML_ASSERT(i1 >= 0 && i1 < ne1);

Expand All @@ -4794,11 +4795,18 @@ void ggml_compute_forward_set_rows(
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_set_rows_f32(params, dst);
if (src1->type == GGML_TYPE_I64) {
ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
} else if (src1->type == GGML_TYPE_I32) {
ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
} else {
GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
}
} break;
default:
{
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3427,7 +3427,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_I64;
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
} break;
case GGML_OP_CPY:
{
Expand Down
72 changes: 40 additions & 32 deletions ggml/src/ggml-cuda/set-rows.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
typedef void (*set_rows_kernel_t)(const char * src, char * dst);

// Generic quantized set_rows kernel template
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static __global__ void k_set_rows_quant(
const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s01, const int64_t s02, const int64_t s03,
Expand Down Expand Up @@ -45,9 +45,9 @@ static __global__ void k_set_rows_quant(
}

// Template dispatch function for quantized set_rows
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static void set_rows_cuda_quant(
const float * src0_d, const int64_t * src1_d, block_type * dst_d,
const float * src0_d, const idx_t * src1_d, block_type * dst_d,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const size_t nb01, const size_t nb02, const size_t nb03,
Expand All @@ -64,15 +64,15 @@ static void set_rows_cuda_quant(
const int64_t s01 = nb01/sizeof(float);
const int64_t s02 = nb02/sizeof(float);
const int64_t s03 = nb03/sizeof(float);
const int64_t s10 = nb10/sizeof(int64_t);
const int64_t s11 = nb11/sizeof(int64_t);
const int64_t s12 = nb12/sizeof(int64_t);
const int64_t s10 = nb10/sizeof(idx_t);
const int64_t s11 = nb11/sizeof(idx_t);
const int64_t s12 = nb12/sizeof(idx_t);
const int64_t s1 = nb1;
const int64_t s2 = nb2;
const int64_t s3 = nb3;

if (ne_total > 0) {
k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
Expand All @@ -82,9 +82,9 @@ static void set_rows_cuda_quant(
}
}

template<typename src_t, typename dst_t>
template<typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(
const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s01, const int64_t s02, const int64_t s03,
Expand Down Expand Up @@ -118,9 +118,9 @@ static __global__ void k_set_rows(
GGML_UNUSED(ne13);
}

template<typename src_t, typename dst_t>
template<typename src_t, typename idx_t, typename dst_t>
static void set_rows_cuda(
const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const size_t nb01, const size_t nb02, const size_t nb03,
Expand All @@ -137,9 +137,9 @@ static void set_rows_cuda(
const int64_t s01 = nb01/sizeof(src_t);
const int64_t s02 = nb02/sizeof(src_t);
const int64_t s03 = nb03/sizeof(src_t);
const int64_t s10 = nb10/sizeof(int64_t);
const int64_t s11 = nb11/sizeof(int64_t);
const int64_t s12 = nb12/sizeof(int64_t);
const int64_t s10 = nb10/sizeof(idx_t);
const int64_t s11 = nb11/sizeof(idx_t);
const int64_t s12 = nb12/sizeof(idx_t);
const int64_t s1 = nb1/sizeof(dst_t);
const int64_t s2 = nb2/sizeof(dst_t);
const int64_t s3 = nb3/sizeof(dst_t);
Expand All @@ -155,23 +155,16 @@ static void set_rows_cuda(
}
}


void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_I64);
template<typename src_t, typename idx_t>
static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const src_t * src0_d = (const src_t *)src0->data;
const idx_t * src1_d = (const idx_t *)src1->data;

GGML_TENSOR_BINARY_OP_LOCALS

const float * src0_d = (const float *)src0->data;
const int64_t * src1_d = (const int64_t *)src1->data;

cudaStream_t stream = ctx.stream();



if (dst->type == GGML_TYPE_F32) {
set_rows_cuda(
src0_d, src1_d, (float*)dst->data,
Expand Down Expand Up @@ -203,7 +196,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
stream
);
} else if (dst->type == GGML_TYPE_Q4_0) {
set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
src0_d, src1_d, (block_q4_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
Expand All @@ -213,7 +206,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
stream
);
} else if (dst->type == GGML_TYPE_Q4_1) {
set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
src0_d, src1_d, (block_q4_1*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
Expand All @@ -223,7 +216,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
stream
);
} else if (dst->type == GGML_TYPE_Q5_0) {
set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
src0_d, src1_d, (block_q5_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
Expand All @@ -233,7 +226,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
stream
);
} else if (dst->type == GGML_TYPE_Q5_1) {
set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
src0_d, src1_d, (block_q5_1*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
Expand All @@ -243,7 +236,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
stream
);
} else if (dst->type == GGML_TYPE_Q8_0) {
set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
src0_d, src1_d, (block_q8_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
Expand All @@ -253,7 +246,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
stream
);
} else if (dst->type == GGML_TYPE_IQ4_NL) {
set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
src0_d, src1_d, (block_iq4_nl*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
Expand All @@ -266,3 +259,18 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
}
}


void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);

if (src1->type == GGML_TYPE_I64) {
set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
} else {
set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
}
}
4 changes: 2 additions & 2 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_librar
return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tdst) {
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
char base[256];
char name[256];

snprintf(base, 256, "kernel_set_rows_%s", ggml_type_name(tdst));
snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tdst);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->type);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);

const int32_t nk0 = ne0/ggml_blck_size(op->type);

Expand Down
41 changes: 25 additions & 16 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -7743,7 +7743,7 @@ kernel void kernel_get_rows_i32(
}
}

template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_set_rows_q32(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
Expand All @@ -7764,7 +7764,7 @@ kernel void kernel_set_rows_q32(
}

const int32_t i10 = i01;
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];

device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
Expand All @@ -7774,7 +7774,7 @@ kernel void kernel_set_rows_q32(
}
}

template<typename T>
template<typename T, typename TI>
kernel void kernel_set_rows_f(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
Expand All @@ -7795,7 +7795,7 @@ kernel void kernel_set_rows_f(
}

const int32_t i10 = i01;
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];

device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
Expand Down Expand Up @@ -8218,22 +8218,31 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
// set rows
//

typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;

template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
#endif

typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;

template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;

template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;

//
// matrix-matrix multiplication
Expand Down
Loading