Skip to content

Commit f2ca68b

Browse files
author
Raghuveer Devulapalli
committed
Use scalar emulation of gather instruction for arg methods
1 parent 1b39637 commit f2ca68b

File tree

4 files changed

+169
-83
lines changed

4 files changed

+169
-83
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ X86_SIMD_SORT_INLINE void argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N)
8585
typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01;
8686
argzmm_t argzmm1 = argtype::loadu(arg);
8787
argzmm_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8);
88-
reg_t arrzmm1 = vtype::template i64gather<sizeof(type_t)>(argzmm1, arr);
88+
reg_t arrzmm1 = vtype::i64gather(arr, arg);
8989
reg_t arrzmm2 = vtype::template mask_i64gather<sizeof(type_t)>(
9090
vtype::zmm_max(), load_mask, argzmm2, arr);
9191
arrzmm1 = sort_zmm_64bit<vtype, argtype>(arrzmm1, argzmm1);
@@ -111,7 +111,7 @@ X86_SIMD_SORT_INLINE void argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N)
111111
X86_SIMD_SORT_UNROLL_LOOP(2)
112112
for (int ii = 0; ii < 2; ++ii) {
113113
argzmm[ii] = argtype::loadu(arg + 8 * ii);
114-
arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
114+
arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
115115
arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
116116
}
117117

@@ -154,7 +154,7 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N)
154154
X86_SIMD_SORT_UNROLL_LOOP(4)
155155
for (int ii = 0; ii < 4; ++ii) {
156156
argzmm[ii] = argtype::loadu(arg + 8 * ii);
157-
arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
157+
arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
158158
arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
159159
}
160160

@@ -206,7 +206,7 @@ X86_SIMD_SORT_UNROLL_LOOP(4)
206206
//X86_SIMD_SORT_UNROLL_LOOP(8)
207207
// for (int ii = 0; ii < 8; ++ii) {
208208
// argzmm[ii] = argtype::loadu(arg + 8*ii);
209-
// arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
209+
// arrzmm[ii] = vtype::i64gather(arr, arg + 8*ii);
210210
// arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
211211
// }
212212
//
@@ -257,17 +257,14 @@ type_t get_pivot_64bit(type_t *arr,
257257
// median of 8
258258
int64_t size = (right - left) / 8;
259259
using reg_t = typename vtype::reg_t;
260-
// TODO: Use gather here too:
261-
__m512i rand_index = _mm512_set_epi64(arg[left + size],
262-
arg[left + 2 * size],
263-
arg[left + 3 * size],
264-
arg[left + 4 * size],
265-
arg[left + 5 * size],
266-
arg[left + 6 * size],
267-
arg[left + 7 * size],
268-
arg[left + 8 * size]);
269-
reg_t rand_vec
270-
= vtype::template i64gather<sizeof(type_t)>(rand_index, arr);
260+
reg_t rand_vec = vtype::set(arr[arg[left + size]],
261+
arr[arg[left + 2 * size]],
262+
arr[arg[left + 3 * size]],
263+
arr[arg[left + 4 * size]],
264+
arr[arg[left + 5 * size]],
265+
arr[arg[left + 6 * size]],
266+
arr[arg[left + 7 * size]],
267+
arr[arg[left + 8 * size]]);
271268
// pivot will never be a nan, since there are no nan's!
272269
reg_t sort = sort_zmm_64bit<vtype>(rand_vec);
273270
return ((type_t *)&sort)[4];

src/avx512-64bit-common.h

Lines changed: 120 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,22 @@ struct ymm_vector<float> {
4545
{
4646
return _mm256_set1_ps(type_max());
4747
}
48-
4948
static zmmi_t
5049
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
5150
{
5251
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
5352
}
53+
static reg_t set(type_t v1,
54+
type_t v2,
55+
type_t v3,
56+
type_t v4,
57+
type_t v5,
58+
type_t v6,
59+
type_t v7,
60+
type_t v8)
61+
{
62+
return _mm256_set_ps(v1, v2, v3, v4, v5, v6, v7, v8);
63+
}
5464
static opmask_t kxor_opmask(opmask_t x, opmask_t y)
5565
{
5666
return _kxor_mask8(x, y);
@@ -86,10 +96,16 @@ struct ymm_vector<float> {
8696
{
8797
return _mm512_mask_i64gather_ps(src, mask, index, base, scale);
8898
}
89-
template <int scale>
90-
static reg_t i64gather(__m512i index, void const *base)
99+
static reg_t i64gather(type_t *arr, int64_t *ind)
91100
{
92-
return _mm512_i64gather_ps(index, base, scale);
101+
return set(arr[ind[7]],
102+
arr[ind[6]],
103+
arr[ind[5]],
104+
arr[ind[4]],
105+
arr[ind[3]],
106+
arr[ind[2]],
107+
arr[ind[1]],
108+
arr[ind[0]]);
93109
}
94110
static reg_t loadu(void const *mem)
95111
{
@@ -195,6 +211,17 @@ struct ymm_vector<uint32_t> {
195211
{
196212
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
197213
}
214+
static reg_t set(type_t v1,
215+
type_t v2,
216+
type_t v3,
217+
type_t v4,
218+
type_t v5,
219+
type_t v6,
220+
type_t v7,
221+
type_t v8)
222+
{
223+
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
224+
}
198225
static opmask_t kxor_opmask(opmask_t x, opmask_t y)
199226
{
200227
return _kxor_mask8(x, y);
@@ -221,10 +248,16 @@ struct ymm_vector<uint32_t> {
221248
{
222249
return _mm512_mask_i64gather_epi32(src, mask, index, base, scale);
223250
}
224-
template <int scale>
225-
static reg_t i64gather(__m512i index, void const *base)
251+
static reg_t i64gather(type_t *arr, int64_t *ind)
226252
{
227-
return _mm512_i64gather_epi32(index, base, scale);
253+
return set(arr[ind[7]],
254+
arr[ind[6]],
255+
arr[ind[5]],
256+
arr[ind[4]],
257+
arr[ind[3]],
258+
arr[ind[2]],
259+
arr[ind[1]],
260+
arr[ind[0]]);
228261
}
229262
static reg_t loadu(void const *mem)
230263
{
@@ -324,6 +357,17 @@ struct ymm_vector<int32_t> {
324357
{
325358
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
326359
}
360+
static reg_t set(type_t v1,
361+
type_t v2,
362+
type_t v3,
363+
type_t v4,
364+
type_t v5,
365+
type_t v6,
366+
type_t v7,
367+
type_t v8)
368+
{
369+
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
370+
}
327371
static opmask_t kxor_opmask(opmask_t x, opmask_t y)
328372
{
329373
return _kxor_mask8(x, y);
@@ -350,10 +394,16 @@ struct ymm_vector<int32_t> {
350394
{
351395
return _mm512_mask_i64gather_epi32(src, mask, index, base, scale);
352396
}
353-
template <int scale>
354-
static reg_t i64gather(__m512i index, void const *base)
397+
static reg_t i64gather(type_t *arr, int64_t *ind)
355398
{
356-
return _mm512_i64gather_epi32(index, base, scale);
399+
return set(arr[ind[7]],
400+
arr[ind[6]],
401+
arr[ind[5]],
402+
arr[ind[4]],
403+
arr[ind[3]],
404+
arr[ind[2]],
405+
arr[ind[1]],
406+
arr[ind[0]]);
357407
}
358408
static reg_t loadu(void const *mem)
359409
{
@@ -456,6 +506,17 @@ struct zmm_vector<int64_t> {
456506
{
457507
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
458508
}
509+
static reg_t set(type_t v1,
510+
type_t v2,
511+
type_t v3,
512+
type_t v4,
513+
type_t v5,
514+
type_t v6,
515+
type_t v7,
516+
type_t v8)
517+
{
518+
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
519+
}
459520
static opmask_t kxor_opmask(opmask_t x, opmask_t y)
460521
{
461522
return _kxor_mask8(x, y);
@@ -482,10 +543,16 @@ struct zmm_vector<int64_t> {
482543
{
483544
return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
484545
}
485-
template <int scale>
486-
static reg_t i64gather(__m512i index, void const *base)
546+
static reg_t i64gather(type_t *arr, int64_t *ind)
487547
{
488-
return _mm512_i64gather_epi64(index, base, scale);
548+
return set(arr[ind[7]],
549+
arr[ind[6]],
550+
arr[ind[5]],
551+
arr[ind[4]],
552+
arr[ind[3]],
553+
arr[ind[2]],
554+
arr[ind[1]],
555+
arr[ind[0]]);
489556
}
490557
static reg_t loadu(void const *mem)
491558
{
@@ -589,16 +656,33 @@ struct zmm_vector<uint64_t> {
589656
{
590657
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
591658
}
659+
static reg_t set(type_t v1,
660+
type_t v2,
661+
type_t v3,
662+
type_t v4,
663+
type_t v5,
664+
type_t v6,
665+
type_t v7,
666+
type_t v8)
667+
{
668+
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
669+
}
592670
template <int scale>
593671
static reg_t
594672
mask_i64gather(reg_t src, opmask_t mask, __m512i index, void const *base)
595673
{
596674
return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
597675
}
598-
template <int scale>
599-
static reg_t i64gather(__m512i index, void const *base)
676+
static reg_t i64gather(type_t *arr, int64_t *ind)
600677
{
601-
return _mm512_i64gather_epi64(index, base, scale);
678+
return set(arr[ind[7]],
679+
arr[ind[6]],
680+
arr[ind[5]],
681+
arr[ind[4]],
682+
arr[ind[3]],
683+
arr[ind[2]],
684+
arr[ind[1]],
685+
arr[ind[0]]);
602686
}
603687
static opmask_t knot_opmask(opmask_t x)
604688
{
@@ -704,13 +788,22 @@ struct zmm_vector<double> {
704788
{
705789
return _mm512_set1_pd(type_max());
706790
}
707-
708791
static zmmi_t
709792
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
710793
{
711794
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
712795
}
713-
796+
static reg_t set(type_t v1,
797+
type_t v2,
798+
type_t v3,
799+
type_t v4,
800+
type_t v5,
801+
type_t v6,
802+
type_t v7,
803+
type_t v8)
804+
{
805+
return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8);
806+
}
714807
static reg_t maskz_loadu(opmask_t mask, void const *mem)
715808
{
716809
return _mm512_maskz_loadu_pd(mask, mem);
@@ -742,10 +835,16 @@ struct zmm_vector<double> {
742835
{
743836
return _mm512_mask_i64gather_pd(src, mask, index, base, scale);
744837
}
745-
template <int scale>
746-
static reg_t i64gather(__m512i index, void const *base)
838+
static reg_t i64gather(type_t *arr, int64_t *ind)
747839
{
748-
return _mm512_i64gather_pd(index, base, scale);
840+
return set(arr[ind[7]],
841+
arr[ind[6]],
842+
arr[ind[5]],
843+
arr[ind[4]],
844+
arr[ind[3]],
845+
arr[ind[2]],
846+
arr[ind[1]],
847+
arr[ind[0]]);
749848
}
750849
static reg_t loadu(void const *mem)
751850
{
@@ -841,7 +940,6 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm)
841940
template <typename vtype, typename reg_t = typename vtype::reg_t>
842941
X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit(reg_t zmm)
843942
{
844-
845943
// 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
846944
zmm = cmp_merge<vtype>(
847945
zmm,

src/avx512-common-argsort.h

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ static inline int64_t partition_avx512(type_t *arr,
7575

7676
if (right - left == vtype::numlanes) {
7777
argzmm_t argvec = argtype::loadu(arg + left);
78-
reg_t vec = vtype::template i64gather<sizeof(type_t)>(argvec, arr);
78+
reg_t vec = vtype::i64gather(arr, arg + left);
7979
int32_t amount_gt_pivot = partition_vec<vtype>(arg,
8080
left,
8181
left + vtype::numlanes,
@@ -91,11 +91,9 @@ static inline int64_t partition_avx512(type_t *arr,
9191

9292
// first and last vtype::numlanes values are partitioned at the end
9393
argzmm_t argvec_left = argtype::loadu(arg + left);
94-
reg_t vec_left
95-
= vtype::template i64gather<sizeof(type_t)>(argvec_left, arr);
94+
reg_t vec_left = vtype::i64gather(arr, arg + left);
9695
argzmm_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes));
97-
reg_t vec_right
98-
= vtype::template i64gather<sizeof(type_t)>(argvec_right, arr);
96+
reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes));
9997
// store points of the vectors
10098
int64_t r_store = right - vtype::numlanes;
10199
int64_t l_store = left;
@@ -113,11 +111,11 @@ static inline int64_t partition_avx512(type_t *arr,
113111
if ((r_store + vtype::numlanes) - right < left - l_store) {
114112
right -= vtype::numlanes;
115113
arg_vec = argtype::loadu(arg + right);
116-
curr_vec = vtype::template i64gather<sizeof(type_t)>(arg_vec, arr);
114+
curr_vec = vtype::i64gather(arr, arg + right);
117115
}
118116
else {
119117
arg_vec = argtype::loadu(arg + left);
120-
curr_vec = vtype::template i64gather<sizeof(type_t)>(arg_vec, arr);
118+
curr_vec = vtype::i64gather(arr, arg + left);
121119
left += vtype::numlanes;
122120
}
123121
// partition the current vector and save it on both sides of the array
@@ -201,12 +199,11 @@ static inline int64_t partition_avx512_unrolled(type_t *arr,
201199
X86_SIMD_SORT_UNROLL_LOOP(8)
202200
for (int ii = 0; ii < num_unroll; ++ii) {
203201
argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii);
204-
vec_left[ii] = vtype::template i64gather<sizeof(type_t)>(
205-
argvec_left[ii], arr);
202+
vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii);
206203
argvec_right[ii] = argtype::loadu(
207204
arg + (right - vtype::numlanes * (num_unroll - ii)));
208-
vec_right[ii] = vtype::template i64gather<sizeof(type_t)>(
209-
argvec_right[ii], arr);
205+
vec_right[ii] = vtype::i64gather(
206+
arr, arg + (right - vtype::numlanes * (num_unroll - ii)));
210207
}
211208
// store points of the vectors
212209
int64_t r_store = right - vtype::numlanes;
@@ -228,16 +225,16 @@ X86_SIMD_SORT_UNROLL_LOOP(8)
228225
for (int ii = 0; ii < num_unroll; ++ii) {
229226
arg_vec[ii]
230227
= argtype::loadu(arg + right + ii * vtype::numlanes);
231-
curr_vec[ii] = vtype::template i64gather<sizeof(type_t)>(
232-
arg_vec[ii], arr);
228+
curr_vec[ii] = vtype::i64gather(
229+
arr, arg + right + ii * vtype::numlanes);
233230
}
234231
}
235232
else {
236233
X86_SIMD_SORT_UNROLL_LOOP(8)
237234
for (int ii = 0; ii < num_unroll; ++ii) {
238235
arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes);
239-
curr_vec[ii] = vtype::template i64gather<sizeof(type_t)>(
240-
arg_vec[ii], arr);
236+
curr_vec[ii] = vtype::i64gather(
237+
arr, arg + left + ii * vtype::numlanes);
241238
}
242239
left += num_unroll * vtype::numlanes;
243240
}

0 commit comments

Comments
 (0)