@@ -45,12 +45,22 @@ struct ymm_vector<float> {
45
45
{
46
46
return _mm256_set1_ps (type_max ());
47
47
}
48
-
49
48
static zmmi_t
50
49
seti (int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
51
50
{
52
51
return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
53
52
}
53
+ static reg_t set (type_t v1,
54
+ type_t v2,
55
+ type_t v3,
56
+ type_t v4,
57
+ type_t v5,
58
+ type_t v6,
59
+ type_t v7,
60
+ type_t v8)
61
+ {
62
+ return _mm256_set_ps (v1, v2, v3, v4, v5, v6, v7, v8);
63
+ }
54
64
static opmask_t kxor_opmask (opmask_t x, opmask_t y)
55
65
{
56
66
return _kxor_mask8 (x, y);
@@ -86,10 +96,16 @@ struct ymm_vector<float> {
86
96
{
87
97
return _mm512_mask_i64gather_ps (src, mask, index, base, scale);
88
98
}
89
- template <int scale>
90
- static reg_t i64gather (__m512i index, void const *base)
99
+ static reg_t i64gather (type_t *arr, int64_t *ind)
91
100
{
92
- return _mm512_i64gather_ps (index, base, scale);
101
+ return set (arr[ind[7 ]],
102
+ arr[ind[6 ]],
103
+ arr[ind[5 ]],
104
+ arr[ind[4 ]],
105
+ arr[ind[3 ]],
106
+ arr[ind[2 ]],
107
+ arr[ind[1 ]],
108
+ arr[ind[0 ]]);
93
109
}
94
110
static reg_t loadu (void const *mem)
95
111
{
@@ -195,6 +211,17 @@ struct ymm_vector<uint32_t> {
195
211
{
196
212
return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
197
213
}
214
+ static reg_t set (type_t v1,
215
+ type_t v2,
216
+ type_t v3,
217
+ type_t v4,
218
+ type_t v5,
219
+ type_t v6,
220
+ type_t v7,
221
+ type_t v8)
222
+ {
223
+ return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
224
+ }
198
225
static opmask_t kxor_opmask (opmask_t x, opmask_t y)
199
226
{
200
227
return _kxor_mask8 (x, y);
@@ -221,10 +248,16 @@ struct ymm_vector<uint32_t> {
221
248
{
222
249
return _mm512_mask_i64gather_epi32 (src, mask, index, base, scale);
223
250
}
224
- template <int scale>
225
- static reg_t i64gather (__m512i index, void const *base)
251
+ static reg_t i64gather (type_t *arr, int64_t *ind)
226
252
{
227
- return _mm512_i64gather_epi32 (index, base, scale);
253
+ return set (arr[ind[7 ]],
254
+ arr[ind[6 ]],
255
+ arr[ind[5 ]],
256
+ arr[ind[4 ]],
257
+ arr[ind[3 ]],
258
+ arr[ind[2 ]],
259
+ arr[ind[1 ]],
260
+ arr[ind[0 ]]);
228
261
}
229
262
static reg_t loadu (void const *mem)
230
263
{
@@ -324,6 +357,17 @@ struct ymm_vector<int32_t> {
324
357
{
325
358
return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
326
359
}
360
+ static reg_t set (type_t v1,
361
+ type_t v2,
362
+ type_t v3,
363
+ type_t v4,
364
+ type_t v5,
365
+ type_t v6,
366
+ type_t v7,
367
+ type_t v8)
368
+ {
369
+ return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
370
+ }
327
371
static opmask_t kxor_opmask (opmask_t x, opmask_t y)
328
372
{
329
373
return _kxor_mask8 (x, y);
@@ -350,10 +394,16 @@ struct ymm_vector<int32_t> {
350
394
{
351
395
return _mm512_mask_i64gather_epi32 (src, mask, index, base, scale);
352
396
}
353
- template <int scale>
354
- static reg_t i64gather (__m512i index, void const *base)
397
+ static reg_t i64gather (type_t *arr, int64_t *ind)
355
398
{
356
- return _mm512_i64gather_epi32 (index, base, scale);
399
+ return set (arr[ind[7 ]],
400
+ arr[ind[6 ]],
401
+ arr[ind[5 ]],
402
+ arr[ind[4 ]],
403
+ arr[ind[3 ]],
404
+ arr[ind[2 ]],
405
+ arr[ind[1 ]],
406
+ arr[ind[0 ]]);
357
407
}
358
408
static reg_t loadu (void const *mem)
359
409
{
@@ -456,6 +506,17 @@ struct zmm_vector<int64_t> {
456
506
{
457
507
return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
458
508
}
509
+ static reg_t set (type_t v1,
510
+ type_t v2,
511
+ type_t v3,
512
+ type_t v4,
513
+ type_t v5,
514
+ type_t v6,
515
+ type_t v7,
516
+ type_t v8)
517
+ {
518
+ return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
519
+ }
459
520
static opmask_t kxor_opmask (opmask_t x, opmask_t y)
460
521
{
461
522
return _kxor_mask8 (x, y);
@@ -482,10 +543,16 @@ struct zmm_vector<int64_t> {
482
543
{
483
544
return _mm512_mask_i64gather_epi64 (src, mask, index, base, scale);
484
545
}
485
- template <int scale>
486
- static reg_t i64gather (__m512i index, void const *base)
546
+ static reg_t i64gather (type_t *arr, int64_t *ind)
487
547
{
488
- return _mm512_i64gather_epi64 (index, base, scale);
548
+ return set (arr[ind[7 ]],
549
+ arr[ind[6 ]],
550
+ arr[ind[5 ]],
551
+ arr[ind[4 ]],
552
+ arr[ind[3 ]],
553
+ arr[ind[2 ]],
554
+ arr[ind[1 ]],
555
+ arr[ind[0 ]]);
489
556
}
490
557
static reg_t loadu (void const *mem)
491
558
{
@@ -589,16 +656,33 @@ struct zmm_vector<uint64_t> {
589
656
{
590
657
return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
591
658
}
659
+ static reg_t set (type_t v1,
660
+ type_t v2,
661
+ type_t v3,
662
+ type_t v4,
663
+ type_t v5,
664
+ type_t v6,
665
+ type_t v7,
666
+ type_t v8)
667
+ {
668
+ return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
669
+ }
592
670
template <int scale>
593
671
static reg_t
594
672
mask_i64gather (reg_t src, opmask_t mask, __m512i index, void const *base)
595
673
{
596
674
return _mm512_mask_i64gather_epi64 (src, mask, index, base, scale);
597
675
}
598
- template <int scale>
599
- static reg_t i64gather (__m512i index, void const *base)
676
+ static reg_t i64gather (type_t *arr, int64_t *ind)
600
677
{
601
- return _mm512_i64gather_epi64 (index, base, scale);
678
+ return set (arr[ind[7 ]],
679
+ arr[ind[6 ]],
680
+ arr[ind[5 ]],
681
+ arr[ind[4 ]],
682
+ arr[ind[3 ]],
683
+ arr[ind[2 ]],
684
+ arr[ind[1 ]],
685
+ arr[ind[0 ]]);
602
686
}
603
687
static opmask_t knot_opmask (opmask_t x)
604
688
{
@@ -704,13 +788,22 @@ struct zmm_vector<double> {
704
788
{
705
789
return _mm512_set1_pd (type_max ());
706
790
}
707
-
708
791
static zmmi_t
709
792
seti (int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
710
793
{
711
794
return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
712
795
}
713
-
796
+ static reg_t set (type_t v1,
797
+ type_t v2,
798
+ type_t v3,
799
+ type_t v4,
800
+ type_t v5,
801
+ type_t v6,
802
+ type_t v7,
803
+ type_t v8)
804
+ {
805
+ return _mm512_set_pd (v1, v2, v3, v4, v5, v6, v7, v8);
806
+ }
714
807
static reg_t maskz_loadu (opmask_t mask, void const *mem)
715
808
{
716
809
return _mm512_maskz_loadu_pd (mask, mem);
@@ -742,10 +835,16 @@ struct zmm_vector<double> {
742
835
{
743
836
return _mm512_mask_i64gather_pd (src, mask, index, base, scale);
744
837
}
745
- template <int scale>
746
- static reg_t i64gather (__m512i index, void const *base)
838
+ static reg_t i64gather (type_t *arr, int64_t *ind)
747
839
{
748
- return _mm512_i64gather_pd (index, base, scale);
840
+ return set (arr[ind[7 ]],
841
+ arr[ind[6 ]],
842
+ arr[ind[5 ]],
843
+ arr[ind[4 ]],
844
+ arr[ind[3 ]],
845
+ arr[ind[2 ]],
846
+ arr[ind[1 ]],
847
+ arr[ind[0 ]]);
749
848
}
750
849
static reg_t loadu (void const *mem)
751
850
{
@@ -841,7 +940,6 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm)
841
940
template <typename vtype, typename reg_t = typename vtype::reg_t >
842
941
X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit (reg_t zmm)
843
942
{
844
-
845
943
// 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
846
944
zmm = cmp_merge<vtype>(
847
945
zmm,
0 commit comments