Commit a8b0089

ggml : remove SVE paths
1 parent d9e0e7c commit a8b0089

4 files changed: +126 −891 lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 68 additions & 195 deletions
@@ -8646,41 +8646,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                     const int ii = i1 + h*nr;
                     const float x_dt = x[ii] * dt_soft_plus;
                     float sumf = 0.0f;
-#if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-                    const int ggml_f32_epr = svcntw();
-                    const int ggml_f32_step = 1 * ggml_f32_epr;
-
-                    const int np = (nc & ~(ggml_f32_step - 1));
-
-                    GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
-
-                    GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
-                    GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
-
-                    for (int i = 0; i < np; i += ggml_f32_step) {
-                        // TODO: maybe unroll more?
-                        for (int j = 0; j < 1; j++) {
-                            GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
-                            GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
-                            GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
-
-                            t0 = GGML_F32_VEC_MUL(t0, adA);
-                            t1 = GGML_F32_VEC_MUL(t1, axdt);
-
-                            t0 = GGML_F32_VEC_ADD(t0, t1);
-
-                            sum = GGML_F32_VEC_FMA(sum, t0, t2);
-
-                            GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
-                        }
-                    }
-
-                    sumf = GGML_F32xt_REDUCE_ONE(sum);
-    #elif defined(__riscv_v_intrinsic)
-                    // todo: RVV implementation
-                    const int np = 0;
-    #else
+#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
                     const int np = (nc & ~(GGML_F32_STEP - 1));
 
                     GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -8711,7 +8677,6 @@ static void ggml_compute_forward_ssm_scan_f32(
 
                     // reduce sum0..sum3 to sum0
                     GGML_F32_VEC_REDUCE(sumf, sum);
-    #endif
 #else
                     const int np = 0;
 #endif
@@ -8741,30 +8706,6 @@ static void ggml_compute_forward_ssm_scan_f32(
                 for (int i1 = 0; i1 < nr; ++i1) {
                     const int ii = i1 + h*nr;
                     const float x_dt = x[ii] * dt_soft_plus;
-#if defined(__ARM_FEATURE_SVE)
-                    svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
-                    svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
-                    svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
-                    // d_state
-                    // TODO: what happens when (d_state % svcntw()) != 0?
-                    for (int64_t k = 0; k < nc; k += svcntw()) {
-                        svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
-                        svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
-                        svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
-                        svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
-
-                        svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
-                        t1 = exp_ps_sve(svptrue_b32(), t1);
-                        svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
-                        vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
-                        r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
-                        GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
-                    }
-                    y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
-#else
                     float sumf = 0.0f;
                     // NOTE: can't really use GGML_SIMD here because d_state is usually 16
                     //       and also because expf is used within the loop.
@@ -8779,7 +8720,6 @@ static void ggml_compute_forward_ssm_scan_f32(
                         s[i] = state;
                     }
                     y[ii] = sumf;
-#endif
                 }
             }
         }
@@ -9231,14 +9171,6 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     #define GGML_F32X_MUL GGML_F32x16_MUL
     #define GGML_F32X_FMA GGML_F32x16_FMA
     #define WKV_VECTOR_SIZE 16
-#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
-    #define GGML_F32X GGML_F32xt
-    #define GGML_F32X_SET1 GGML_F32xt_SET1
-    #define GGML_F32X_LOAD GGML_F32xt_LOAD
-    #define GGML_F32X_STORE GGML_F32xt_STORE
-    #define GGML_F32X_MUL GGML_F32xt_MUL
-    #define GGML_F32X_FMA GGML_F32xt_FMA
-    #define WKV_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define GGML_F32X GGML_F32x4
     #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -9251,11 +9183,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
 
     #ifdef WKV_VECTOR_SIZE
         int wkv_vector_size;
-        #if defined(__ARM_FEATURE_SVE)
-            wkv_vector_size = svcntw();
-        #else
-            wkv_vector_size = WKV_VECTOR_SIZE;
-        #endif
+        wkv_vector_size = WKV_VECTOR_SIZE;
         const int64_t vec_count = head_size / wkv_vector_size;
 
         for (int64_t t = 0; t < T; t++) {
@@ -9447,14 +9375,6 @@ static void ggml_compute_forward_gla_f32(
     #define GGML_F32X_MUL GGML_F32x16_MUL
     #define GGML_F32X_FMA GGML_F32x16_FMA
     #define GLA_VECTOR_SIZE 16
-#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
-    #define GGML_F32X GGML_F32xt
-    #define GGML_F32X_SET1 GGML_F32xt_SET1
-    #define GGML_F32X_LOAD GGML_F32xt_LOAD
-    #define GGML_F32X_STORE GGML_F32xt_STORE
-    #define GGML_F32X_MUL GGML_F32xt_MUL
-    #define GGML_F32X_FMA GGML_F32xt_FMA
-    #define GLA_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define GGML_F32X GGML_F32x4
     #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -9467,11 +9387,7 @@ static void ggml_compute_forward_gla_f32(
 
     #ifdef GLA_VECTOR_SIZE
         int gla_vector_size;
-        #if defined(__ARM_FEATURE_SVE)
-            gla_vector_size = svcntw();
-        #else
-            gla_vector_size = GLA_VECTOR_SIZE;
-        #endif
+        gla_vector_size = GLA_VECTOR_SIZE;
         const int64_t vec_count = head_size / gla_vector_size;
 
         for (int64_t t = 0; t < T; t++) {
@@ -9631,127 +9547,84 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
     int64_t h_stride_2d = head_size * head_size;
 
-    #if defined(GGML_SIMD)
-        #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
-            // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
-            for (int64_t t = 0; t < T; t++) {
-                int64_t t_offset = t * t_stride;
-                int64_t state_offset = head_size * C * (t / (T / n_seqs));
-                float * state_cur = state + state_offset;
-                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
-
-                for (int64_t h = h_start; h < h_end; h++) {
-                    int64_t h_offset = h * h_stride;
-                    int64_t t_h_offset = t_offset + h_offset;
-                    int64_t h_2d_offset = h * h_stride_2d;
-
-                    for (int64_t i = 0; i < head_size; i++) {
-                        int64_t t_h_i_offset = t_h_offset + i;
-                        int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                        float v_val = v[t_h_i_offset];
-
-                        float sa = 0, result = 0;
-                        for (int64_t j = 0; j < head_size; j++) {
-                            sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
-                        }
+    #if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
 
-                        for (int64_t j = 0; j < head_size; j++) {
-                            int64_t t_h_j_offset = t_h_offset + j;
-                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                            float r_val = r[t_h_j_offset];
-                            float w_val = w[t_h_j_offset];
-                            float k_val = k[t_h_j_offset];
-                            float b_val = b[t_h_j_offset];
-                            float kv_val = v_val * k_val;
-                            float prev_state_val = state_prev[h_2d_i_j_offset];
-                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                            result += state_cur[h_2d_i_j_offset] * r_val;
-                        }
-                        dst_data[t_h_i_offset] = result;
-                    }
-                }
-            }
-        #else
-            for (int64_t t = 0; t < T; t++) {
-                int64_t t_offset = t * t_stride;
-                int64_t state_offset = head_size * C * (t / (T / n_seqs));
-                float * state_cur = state + state_offset;
-                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
-
-                for (int64_t h = h_start; h < h_end; h++) {
-                    int64_t h_offset = h * h_stride;
-                    int64_t t_h_offset = t_offset + h_offset;
-                    int64_t h_2d_offset = h * h_stride_2d;
-
-                    for (int64_t ii = 0; ii < head_size; ii++) {
-                        int64_t t_h_i_offset = t_h_offset + ii;
-                        int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                        GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
-
-                        float sa = 0;
-                        {
-                            GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                            GGML_F32_VEC ax[GGML_F32_ARR];
-                            GGML_F32_VEC ay[GGML_F32_ARR];
-                            for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                                for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                                    ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
-                                    ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
-                                    sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                                }
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
                             }
-                            GGML_F32_VEC_REDUCE(sa, sum);
                         }
+                        GGML_F32_VEC_REDUCE(sa, sum);
+                    }
 
-                        GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
 
-                        int64_t j = 0;
-                        GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                        for (; j < head_size; j += GGML_F32_STEP) {
-                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                                int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
-                                int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+                    int64_t j = 0;
+                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
 
-                                GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                                GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                                GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                                GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
 
-                                k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
 
-                                GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                                // kv + s * decay + sa * b
-                                state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                                state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                                GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
 
-                                result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
-                            }
-                        }
-                        GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                        // There shouldn't be left-overs though.
-                        for (; j < head_size; j++) {
-                            int64_t t_h_j_offset = t_h_offset + j;
-                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                            float r_val = r[t_h_j_offset];
-                            float w_val = w[t_h_j_offset];
-                            float k_val = k[t_h_j_offset];
-                            float b_val = b[t_h_j_offset];
-                            float kv_val = v[t_h_i_offset] * k_val;
-
-                            float prev_state_val = state_prev[h_2d_i_j_offset];
-                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                            dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
                         }
                     }
+                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // There shouldn't be left-overs though.
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                    }
                 }
             }
-        #endif
+        }
     #else
         for (int64_t t = 0; t < T; t++) {
             int64_t t_offset = t * t_stride;
