Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,13 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128

// vaddvq_s16
// vpaddq_s16
// vpaddq_s32
// vaddvq_s32
// vaddvq_f32
// vmaxvq_f32
// vcvtnq_s32_f32
// vzip1_u8
// vzip2_u8

inline static int32_t vaddvq_s16(int16x8_t v) {
return
Expand All @@ -291,6 +294,12 @@ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
return vcombine_s16(a0, b0);
}

inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
return vcombine_s32(a0, b0);
}

inline static int32_t vaddvq_s32(int32x4_t v) {
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}
Expand All @@ -316,6 +325,28 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
return res;
}

inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
uint8x8_t res;

res[0] = a[0]; res[1] = b[0];
res[2] = a[1]; res[3] = b[1];
res[4] = a[2]; res[5] = b[2];
res[6] = a[3]; res[7] = b[3];

return res;
}

inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
uint8x8_t res;

res[0] = a[4]; res[1] = b[4];
res[2] = a[5]; res[3] = b[5];
res[4] = a[6]; res[5] = b[6];
res[6] = a[7]; res[7] = b[7];

return res;
}

// vld1q_s16_x2
// vld1q_u8_x2
// vld1q_u8_x4
Expand Down Expand Up @@ -7554,9 +7585,9 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest

const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

int8x16x4_t q2u;
int8x16x4_t q2s;
int8x16x4_t q8b;
ggml_int8x16x4_t q2u;
ggml_int8x16x4_t q2s;
ggml_int8x16x4_t q8b;

int32x4x4_t scales32;

Expand All @@ -7578,7 +7609,7 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
int32x4_t sumi = vdupq_n_s32(0);
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
q8b = vld1q_s8_x4(q8); q8 += 64;
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
Expand Down