@@ -8646,41 +8646,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                const int ii = i1 + h*nr;
                const float x_dt = x[ii] * dt_soft_plus;
                float sumf = 0.0f;
- #if defined(GGML_SIMD)
-     #if defined(__ARM_FEATURE_SVE)
-                 const int ggml_f32_epr = svcntw();
-                 const int ggml_f32_step = 1 * ggml_f32_epr;
-
-                 const int np = (nc & ~(ggml_f32_step - 1));
-
-                 GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
-
-                 GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
-                 GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
-
-                 for (int i = 0; i < np; i += ggml_f32_step) {
-                     // TODO: maybe unroll more?
-                     for (int j = 0; j < 1; j++) {
-                         GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
-                         GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
-                         GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
-
-                         t0 = GGML_F32_VEC_MUL(t0, adA);
-                         t1 = GGML_F32_VEC_MUL(t1, axdt);
-
-                         t0 = GGML_F32_VEC_ADD(t0, t1);
-
-                         sum = GGML_F32_VEC_FMA(sum, t0, t2);
-
-                         GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
-                     }
-                 }
-
-                 sumf = GGML_F32xt_REDUCE_ONE(sum);
-     #elif defined(__riscv_v_intrinsic)
-                 // todo: RVV implementation
-                 const int np = 0;
-     #else
+ #if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
                const int np = (nc & ~(GGML_F32_STEP - 1));

                GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -8711,7 +8677,6 @@ static void ggml_compute_forward_ssm_scan_f32(

                // reduce sum0..sum3 to sum0
                GGML_F32_VEC_REDUCE(sumf, sum);
-     #endif
#else
                const int np = 0;
#endif
@@ -8741,30 +8706,6 @@ static void ggml_compute_forward_ssm_scan_f32(
            for (int i1 = 0; i1 < nr; ++i1) {
                const int ii = i1 + h*nr;
                const float x_dt = x[ii] * dt_soft_plus;
- #if defined(__ARM_FEATURE_SVE)
-                 svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
-                 svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
-                 svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
-                 // d_state
-                 // TODO: what happens when (d_state % svcntw()) != 0?
-                 for (int64_t k = 0; k < nc; k += svcntw()) {
-                     svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
-                     svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
-                     svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
-                     svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
-
-                     svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
-                     t1 = exp_ps_sve(svptrue_b32(), t1);
-                     svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
-                     vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
-                     r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
-                     GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
-                 }
-                 y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
- #else
                float sumf = 0.0f;
                // NOTE: can't really use GGML_SIMD here because d_state is usually 16
                // and also because expf is used within the loop.
@@ -8779,7 +8720,6 @@ static void ggml_compute_forward_ssm_scan_f32(
                    s[i] = state;
                }
                y[ii] = sumf;
- #endif
            }
        }
    }
@@ -9231,14 +9171,6 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
        #define GGML_F32X_MUL GGML_F32x16_MUL
        #define GGML_F32X_FMA GGML_F32x16_FMA
        #define WKV_VECTOR_SIZE 16
-     #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
-         #define GGML_F32X GGML_F32xt
-         #define GGML_F32X_SET1 GGML_F32xt_SET1
-         #define GGML_F32X_LOAD GGML_F32xt_LOAD
-         #define GGML_F32X_STORE GGML_F32xt_STORE
-         #define GGML_F32X_MUL GGML_F32xt_MUL
-         #define GGML_F32X_FMA GGML_F32xt_FMA
-         #define WKV_VECTOR_SIZE 8
    #elif defined(__ARM_NEON) && defined(__aarch64__)
        #define GGML_F32X GGML_F32x4
        #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -9251,11 +9183,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(

    #ifdef WKV_VECTOR_SIZE
        int wkv_vector_size;
-         #if defined(__ARM_FEATURE_SVE)
-             wkv_vector_size = svcntw();
-         #else
-             wkv_vector_size = WKV_VECTOR_SIZE;
-         #endif
+         wkv_vector_size = WKV_VECTOR_SIZE;
        const int64_t vec_count = head_size / wkv_vector_size;

        for (int64_t t = 0; t < T; t++) {
@@ -9447,14 +9375,6 @@ static void ggml_compute_forward_gla_f32(
        #define GGML_F32X_MUL GGML_F32x16_MUL
        #define GGML_F32X_FMA GGML_F32x16_FMA
        #define GLA_VECTOR_SIZE 16
-     #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
-         #define GGML_F32X GGML_F32xt
-         #define GGML_F32X_SET1 GGML_F32xt_SET1
-         #define GGML_F32X_LOAD GGML_F32xt_LOAD
-         #define GGML_F32X_STORE GGML_F32xt_STORE
-         #define GGML_F32X_MUL GGML_F32xt_MUL
-         #define GGML_F32X_FMA GGML_F32xt_FMA
-         #define GLA_VECTOR_SIZE 8
    #elif defined(__ARM_NEON) && defined(__aarch64__)
        #define GGML_F32X GGML_F32x4
        #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -9467,11 +9387,7 @@ static void ggml_compute_forward_gla_f32(

    #ifdef GLA_VECTOR_SIZE
        int gla_vector_size;
-         #if defined(__ARM_FEATURE_SVE)
-             gla_vector_size = svcntw();
-         #else
-             gla_vector_size = GLA_VECTOR_SIZE;
-         #endif
+         gla_vector_size = GLA_VECTOR_SIZE;
        const int64_t vec_count = head_size / gla_vector_size;

        for (int64_t t = 0; t < T; t++) {
@@ -9631,127 +9547,84 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
    int64_t h_stride_2d = head_size * head_size;

- #if defined(GGML_SIMD)
-     #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
-         // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
-         for (int64_t t = 0; t < T; t++) {
-             int64_t t_offset = t * t_stride;
-             int64_t state_offset = head_size * C * (t / (T / n_seqs));
-             float * state_cur = state + state_offset;
-             float * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src[6]->data + state_offset;
-
-             for (int64_t h = h_start; h < h_end; h++) {
-                 int64_t h_offset = h * h_stride;
-                 int64_t t_h_offset = t_offset + h_offset;
-                 int64_t h_2d_offset = h * h_stride_2d;
-
-                 for (int64_t i = 0; i < head_size; i++) {
-                     int64_t t_h_i_offset = t_h_offset + i;
-                     int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                     float v_val = v[t_h_i_offset];
-
-                     float sa = 0, result = 0;
-                     for (int64_t j = 0; j < head_size; j++) {
-                         sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
-                     }
+ #if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
+     for (int64_t t = 0; t < T; t++) {
+         int64_t t_offset = t * t_stride;
+         int64_t state_offset = head_size * C * (t / (T / n_seqs));
+         float * state_cur = state + state_offset;
+         float * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src[6]->data + state_offset;

-                     for (int64_t j = 0; j < head_size; j++) {
-                         int64_t t_h_j_offset = t_h_offset + j;
-                         int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                         float r_val = r[t_h_j_offset];
-                         float w_val = w[t_h_j_offset];
-                         float k_val = k[t_h_j_offset];
-                         float b_val = b[t_h_j_offset];
-                         float kv_val = v_val * k_val;
-                         float prev_state_val = state_prev[h_2d_i_j_offset];
-                         state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                         result += state_cur[h_2d_i_j_offset] * r_val;
-                     }
-                     dst_data[t_h_i_offset] = result;
-                 }
-             }
-         }
-     #else
-         for (int64_t t = 0; t < T; t++) {
-             int64_t t_offset = t * t_stride;
-             int64_t state_offset = head_size * C * (t / (T / n_seqs));
-             float * state_cur = state + state_offset;
-             float * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src[6]->data + state_offset;
-
-             for (int64_t h = h_start; h < h_end; h++) {
-                 int64_t h_offset = h * h_stride;
-                 int64_t t_h_offset = t_offset + h_offset;
-                 int64_t h_2d_offset = h * h_stride_2d;
-
-                 for (int64_t ii = 0; ii < head_size; ii++) {
-                     int64_t t_h_i_offset = t_h_offset + ii;
-                     int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                     GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
-
-                     float sa = 0;
-                     {
-                         GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                         GGML_F32_VEC ax[GGML_F32_ARR];
-                         GGML_F32_VEC ay[GGML_F32_ARR];
-                         for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                             for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                                 ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
-                                 ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
-                                 sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                             }
+     for (int64_t h = h_start; h < h_end; h++) {
+         int64_t h_offset = h * h_stride;
+         int64_t t_h_offset = t_offset + h_offset;
+         int64_t h_2d_offset = h * h_stride_2d;
+
+         for (int64_t ii = 0; ii < head_size; ii++) {
+             int64_t t_h_i_offset = t_h_offset + ii;
+             int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+             GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+             float sa = 0;
+             {
+                 GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                 GGML_F32_VEC ax[GGML_F32_ARR];
+                 GGML_F32_VEC ay[GGML_F32_ARR];
+                 for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                     for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                         ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                         ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                         sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
                         }
-                         GGML_F32_VEC_REDUCE(sa, sum);
                     }
+                 GGML_F32_VEC_REDUCE(sa, sum);
+             }

-                     GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+             GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

-                     int64_t j = 0;
-                     GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                     for (; j < head_size; j += GGML_F32_STEP) {
-                         for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                             int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
-                             int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+             int64_t j = 0;
+             GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+             for (; j < head_size; j += GGML_F32_STEP) {
+                 for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                     int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                     int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

-                             GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                             GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                             GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                             GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+                     GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                     GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                     GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                     GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

-                             k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+                     k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

-                             GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                             // kv + s * decay + sa * b
-                             state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                             state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                             GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+                     GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                     // kv + s * decay + sa * b
+                     state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                     state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                     GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

-                             result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
-                         }
-                     }
-                     GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                     // There shouldn't be left-overs though.
-                     for (; j < head_size; j++) {
-                         int64_t t_h_j_offset = t_h_offset + j;
-                         int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                         float r_val = r[t_h_j_offset];
-                         float w_val = w[t_h_j_offset];
-                         float k_val = k[t_h_j_offset];
-                         float b_val = b[t_h_j_offset];
-                         float kv_val = v[t_h_i_offset] * k_val;
-
-                         float prev_state_val = state_prev[h_2d_i_j_offset];
-                         state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                         dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                     result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
                     }
                 }
+                 GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                 // There shouldn't be left-overs though.
+                 for (; j < head_size; j++) {
+                     int64_t t_h_j_offset = t_h_offset + j;
+                     int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                     float r_val = r[t_h_j_offset];
+                     float w_val = w[t_h_j_offset];
+                     float k_val = k[t_h_j_offset];
+                     float b_val = b[t_h_j_offset];
+                     float kv_val = v[t_h_i_offset] * k_val;
+
+                     float prev_state_val = state_prev[h_2d_i_j_offset];
+                     state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                     dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                 }
             }
         }
-     #endif
+     }
#else
    for (int64_t t = 0; t < T; t++) {
        int64_t t_offset = t * t_stride;