From f2ca68b5ce701f61634a691389529c0cecc63610 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 21 Aug 2023 12:43:03 -0700 Subject: [PATCH 1/3] Use scalar emulation of gather instruction for arg methods --- src/avx512-64bit-argsort.hpp | 27 +++---- src/avx512-64bit-common.h | 142 +++++++++++++++++++++++++++++------ src/avx512-common-argsort.h | 27 +++---- src/avx512-common-qsort.h | 56 ++++++-------- 4 files changed, 169 insertions(+), 83 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index e8254ce1..7c6a34c1 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -85,7 +85,7 @@ X86_SIMD_SORT_INLINE void argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N) typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; argzmm_t argzmm1 = argtype::loadu(arg); argzmm_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8); - reg_t arrzmm1 = vtype::template i64gather(argzmm1, arr); + reg_t arrzmm1 = vtype::i64gather(arr, arg); reg_t arrzmm2 = vtype::template mask_i64gather( vtype::zmm_max(), load_mask, argzmm2, arr); arrzmm1 = sort_zmm_64bit(arrzmm1, argzmm1); @@ -111,7 +111,7 @@ X86_SIMD_SORT_INLINE void argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N) X86_SIMD_SORT_UNROLL_LOOP(2) for (int ii = 0; ii < 2; ++ii) { argzmm[ii] = argtype::loadu(arg + 8 * ii); - arrzmm[ii] = vtype::template i64gather(argzmm[ii], arr); + arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii); arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); } @@ -154,7 +154,7 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N) X86_SIMD_SORT_UNROLL_LOOP(4) for (int ii = 0; ii < 4; ++ii) { argzmm[ii] = argtype::loadu(arg + 8 * ii); - arrzmm[ii] = vtype::template i64gather(argzmm[ii], arr); + arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii); arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); } @@ -206,7 +206,7 @@ X86_SIMD_SORT_UNROLL_LOOP(4) //X86_SIMD_SORT_UNROLL_LOOP(8) // for (int ii = 0; ii < 8; ++ii) { // argzmm[ii] = argtype::loadu(arg + 8*ii); -// arrzmm[ii] = vtype::template i64gather(argzmm[ii], arr); +// arrzmm[ii] = vtype::i64gather(argzmm[ii], arr); // arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); // } // @@ -257,17 +257,14 @@ type_t get_pivot_64bit(type_t *arr, // median of 8 int64_t size = (right - left) / 8; using reg_t = typename vtype::reg_t; - // TODO: Use gather here too: - __m512i rand_index = _mm512_set_epi64(arg[left + size], - arg[left + 2 * size], - arg[left + 3 * size], - arg[left + 4 * size], - arg[left + 5 * size], - arg[left + 6 * size], - arg[left + 7 * size], - arg[left + 8 * size]); - reg_t rand_vec - = vtype::template i64gather(rand_index, arr); + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]], + arr[arg[left + 5 * size]], + arr[arg[left + 6 * size]], + arr[arg[left + 7 * size]], + arr[arg[left + 8 * size]]); // pivot will never be a nan, since there are no nan's! 
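// The hunks above replace the hardware gathers (the old i64gather(__m512i index,
// void const *base) members built on _mm512_i64gather_*) with a scalar emulation:
// read each element through the index array and rebuild the vector with a set()
// call. A minimal standalone sketch of that pattern, assuming plain double keys
// and a hypothetical helper name that is not part of this patch:
#include <cstdint>
#include <immintrin.h>

static inline __m512d scalar_i64gather_pd(const double *arr,
                                          const int64_t *ind)
{
    // _mm512_set_pd lists lanes from 7 down to 0, so ind[7] comes first;
    // the patch's set() wrappers follow the same convention.
    return _mm512_set_pd(arr[ind[7]], arr[ind[6]], arr[ind[5]], arr[ind[4]],
                         arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
}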
reg_t sort = sort_zmm_64bit(rand_vec); return ((type_t *)&sort)[4]; diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 0353507e..bbbd440b 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -45,12 +45,22 @@ struct ymm_vector { { return _mm256_set1_ps(type_max()); } - static zmmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); } + static reg_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm256_set_ps(v1, v2, v3, v4, v5, v6, v7, v8); + } static opmask_t kxor_opmask(opmask_t x, opmask_t y) { return _kxor_mask8(x, y); @@ -86,10 +96,16 @@ struct ymm_vector { { return _mm512_mask_i64gather_ps(src, mask, index, base, scale); } - template - static reg_t i64gather(__m512i index, void const *base) + static reg_t i64gather(type_t *arr, int64_t *ind) { - return _mm512_i64gather_ps(index, base, scale); + return set(arr[ind[7]], + arr[ind[6]], + arr[ind[5]], + arr[ind[4]], + arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -195,6 +211,17 @@ struct ymm_vector { { return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); } + static reg_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } static opmask_t kxor_opmask(opmask_t x, opmask_t y) { return _kxor_mask8(x, y); @@ -221,10 +248,16 @@ struct ymm_vector { { return _mm512_mask_i64gather_epi32(src, mask, index, base, scale); } - template - static reg_t i64gather(__m512i index, void const *base) + static reg_t i64gather(type_t *arr, int64_t *ind) { - return _mm512_i64gather_epi32(index, base, scale); + return set(arr[ind[7]], + arr[ind[6]], + arr[ind[5]], + arr[ind[4]], + arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -324,6 +357,17 @@ struct ymm_vector { { return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); } + static reg_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } static opmask_t kxor_opmask(opmask_t x, opmask_t y) { return _kxor_mask8(x, y); @@ -350,10 +394,16 @@ struct ymm_vector { { return _mm512_mask_i64gather_epi32(src, mask, index, base, scale); } - template - static reg_t i64gather(__m512i index, void const *base) + static reg_t i64gather(type_t *arr, int64_t *ind) { - return _mm512_i64gather_epi32(index, base, scale); + return set(arr[ind[7]], + arr[ind[6]], + arr[ind[5]], + arr[ind[4]], + arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -456,6 +506,17 @@ struct zmm_vector { { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } + static reg_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } static opmask_t kxor_opmask(opmask_t x, opmask_t y) { return _kxor_mask8(x, y); @@ -482,10 +543,16 @@ struct zmm_vector { { return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); } - template - static reg_t i64gather(__m512i index, void const *base) + static reg_t i64gather(type_t *arr, int64_t *ind) { - return _mm512_i64gather_epi64(index, base, scale); + return set(arr[ind[7]], + arr[ind[6]], + arr[ind[5]], + 
arr[ind[4]], + arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -589,16 +656,33 @@ struct zmm_vector { { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } + static reg_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } template static reg_t mask_i64gather(reg_t src, opmask_t mask, __m512i index, void const *base) { return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); } - template - static reg_t i64gather(__m512i index, void const *base) + static reg_t i64gather(type_t *arr, int64_t *ind) { - return _mm512_i64gather_epi64(index, base, scale); + return set(arr[ind[7]], + arr[ind[6]], + arr[ind[5]], + arr[ind[4]], + arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); } static opmask_t knot_opmask(opmask_t x) { @@ -704,13 +788,22 @@ struct zmm_vector { { return _mm512_set1_pd(type_max()); } - static zmmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } - + static reg_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); + } static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm512_maskz_loadu_pd(mask, mem); @@ -742,10 +835,16 @@ struct zmm_vector { { return _mm512_mask_i64gather_pd(src, mask, index, base, scale); } - template - static reg_t i64gather(__m512i index, void const *base) + static reg_t i64gather(type_t *arr, int64_t *ind) { - return _mm512_i64gather_pd(index, base, scale); + return set(arr[ind[7]], + arr[ind[6]], + arr[ind[5]], + arr[ind[4]], + arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -841,7 +940,6 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm) template X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit(reg_t zmm) { - // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 zmm = cmp_merge( zmm, diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h index 015a6bd7..5c73aede 100644 --- a/src/avx512-common-argsort.h +++ b/src/avx512-common-argsort.h @@ -75,7 +75,7 @@ static inline int64_t partition_avx512(type_t *arr, if (right - left == vtype::numlanes) { argzmm_t argvec = argtype::loadu(arg + left); - reg_t vec = vtype::template i64gather(argvec, arr); + reg_t vec = vtype::i64gather(arr, arg + left); int32_t amount_gt_pivot = partition_vec(arg, left, left + vtype::numlanes, @@ -91,11 +91,9 @@ static inline int64_t partition_avx512(type_t *arr, // first and last vtype::numlanes values are partitioned at the end argzmm_t argvec_left = argtype::loadu(arg + left); - reg_t vec_left - = vtype::template i64gather(argvec_left, arr); + reg_t vec_left = vtype::i64gather(arr, arg + left); argzmm_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); - reg_t vec_right - = vtype::template i64gather(argvec_right, arr); + reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); // store points of the vectors int64_t r_store = right - vtype::numlanes; int64_t l_store = left; @@ -113,11 +111,11 @@ static inline int64_t partition_avx512(type_t *arr, if ((r_store + vtype::numlanes) - right < left - l_store) { right -= vtype::numlanes; arg_vec = argtype::loadu(arg + right); - curr_vec = vtype::template i64gather(arg_vec, arr); + curr_vec = vtype::i64gather(arr, arg + right); } else { 
arg_vec = argtype::loadu(arg + left); - curr_vec = vtype::template i64gather(arg_vec, arr); + curr_vec = vtype::i64gather(arr, arg + left); left += vtype::numlanes; } // partition the current vector and save it on both sides of the array @@ -201,12 +199,11 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); - vec_left[ii] = vtype::template i64gather( - argvec_left[ii], arr); + vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); argvec_right[ii] = argtype::loadu( arg + (right - vtype::numlanes * (num_unroll - ii))); - vec_right[ii] = vtype::template i64gather( - argvec_right[ii], arr); + vec_right[ii] = vtype::i64gather( + arr, arg + (right - vtype::numlanes * (num_unroll - ii))); } // store points of the vectors int64_t r_store = right - vtype::numlanes; @@ -228,16 +225,16 @@ X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { arg_vec[ii] = argtype::loadu(arg + right + ii * vtype::numlanes); - curr_vec[ii] = vtype::template i64gather( - arg_vec[ii], arr); + curr_vec[ii] = vtype::i64gather( + arr, arg + right + ii * vtype::numlanes); } } else { X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); - curr_vec[ii] = vtype::template i64gather( - arg_vec[ii], arr); + curr_vec[ii] = vtype::i64gather( + arr, arg + left + ii * vtype::numlanes); } left += num_unroll * vtype::numlanes; } diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 70b63af3..4ebd9475 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -758,28 +758,23 @@ X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, // median of 16 int64_t size = (right - left) / 16; using zmm_t = typename vtype::reg_t; - using ymm_t = typename vtype::halfreg_t; - __m512i rand_index1 = _mm512_set_epi64(left + size, - left + 2 * size, - left + 3 * size, - left + 4 * size, - left + 5 * size, - left + 6 * size, - left + 7 * size, - left + 8 * size); - __m512i rand_index2 = _mm512_set_epi64(left + 9 * size, - left + 10 * size, - left + 11 * size, - left + 12 * size, - left + 13 * size, - left + 14 * size, - left + 15 * size, - left + 16 * size); - ymm_t rand_vec1 - = vtype::template i64gather(rand_index1, arr); - ymm_t rand_vec2 - = vtype::template i64gather(rand_index2, arr); - zmm_t rand_vec = vtype::merge(rand_vec1, rand_vec2); + type_t vec_arr[16] = {arr[left + size], + arr[left + 2 * size], + arr[left + 3 * size], + arr[left + 4 * size], + arr[left + 5 * size], + arr[left + 6 * size], + arr[left + 7 * size], + arr[left + 8 * size], + arr[left + 9 * size], + arr[left + 10 * size], + arr[left + 11 * size], + arr[left + 12 * size], + arr[left + 13 * size], + arr[left + 14 * size], + arr[left + 15 * size], + arr[left + 16 * size]}; + zmm_t rand_vec = vtype::loadu(vec_arr); zmm_t sort = vtype::sort_vec(rand_vec); // pivot will never be a nan, since there are no nan's! 
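// The get_pivot_32bit() hunk above drops two 256-bit gathers plus a merge in
// favour of sixteen scalar loads into a small buffer followed by one loadu.
// A self-contained sketch of that sampling pattern, assuming float keys and an
// illustrative function name that is not part of the library:
#include <cstdint>
#include <immintrin.h>

static inline __m512 sample_16_pivot_candidates_ps(const float *arr,
                                                   int64_t left,
                                                   int64_t right)
{
    int64_t size = (right - left) / 16;
    float buf[16];
    for (int i = 0; i < 16; ++i) {
        // same evenly spaced sample points the patch uses: left + size, left + 2*size, ...
        buf[i] = arr[left + (i + 1) * size];
    }
    // the caller sorts this register and reads lane 8 as the median of 16
    return _mm512_loadu_ps(buf);
}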
return ((type_t *)&sort)[8]; @@ -793,15 +788,14 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, // median of 8 int64_t size = (right - left) / 8; using zmm_t = typename vtype::reg_t; - __m512i rand_index = _mm512_set_epi64(left + size, - left + 2 * size, - left + 3 * size, - left + 4 * size, - left + 5 * size, - left + 6 * size, - left + 7 * size, - left + 8 * size); - zmm_t rand_vec = vtype::template i64gather(rand_index, arr); + zmm_t rand_vec = vtype::set(arr[left + size], + arr[left + 2 * size], + arr[left + 3 * size], + arr[left + 4 * size], + arr[left + 5 * size], + arr[left + 6 * size], + arr[left + 7 * size], + arr[left + 8 * size]); // pivot will never be a nan, since there are no nan's! zmm_t sort = vtype::sort_vec(rand_vec); return ((type_t *)&sort)[4]; From 4767a4c39d33a3748a318e0d7b2d183630aace9e Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 7 Sep 2023 13:35:39 -0700 Subject: [PATCH 2/3] Replace zmm_t with reg_t --- src/avx512-64bit-argsort.hpp | 12 +- src/avx512-64bit-keyvalue-networks.hpp | 360 ++++++++++++------------- src/avx512-common-argsort.h | 16 +- src/avx512-common-qsort.h | 80 +++--- 4 files changed, 234 insertions(+), 234 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 7c6a34c1..c9c5e961 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -67,7 +67,7 @@ X86_SIMD_SORT_INLINE void argsort_8_64bit(type_t *arr, int64_t *arg, int32_t N) { using reg_t = typename vtype::reg_t; typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; - argzmm_t argzmm = argtype::maskz_loadu(load_mask, arg); + argreg_t argzmm = argtype::maskz_loadu(load_mask, arg); reg_t arrzmm = vtype::template mask_i64gather( vtype::zmm_max(), load_mask, argzmm, arr); arrzmm = sort_zmm_64bit(arrzmm, argzmm); @@ -83,8 +83,8 @@ X86_SIMD_SORT_INLINE void argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N) } using reg_t = typename vtype::reg_t; typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - argzmm_t argzmm1 = argtype::loadu(arg); - argzmm_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8); + argreg_t argzmm1 = argtype::loadu(arg); + argreg_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8); reg_t arrzmm1 = vtype::i64gather(arr, arg); reg_t arrzmm2 = vtype::template mask_i64gather( vtype::zmm_max(), load_mask, argzmm2, arr); @@ -106,7 +106,7 @@ X86_SIMD_SORT_INLINE void argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N) using reg_t = typename vtype::reg_t; using opmask_t = typename vtype::opmask_t; reg_t arrzmm[4]; - argzmm_t argzmm[4]; + argreg_t argzmm[4]; X86_SIMD_SORT_UNROLL_LOOP(2) for (int ii = 0; ii < 2; ++ii) { @@ -149,7 +149,7 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N) using reg_t = typename vtype::reg_t; using opmask_t = typename vtype::opmask_t; reg_t arrzmm[8]; - argzmm_t argzmm[8]; + argreg_t argzmm[8]; X86_SIMD_SORT_UNROLL_LOOP(4) for (int ii = 0; ii < 4; ++ii) { @@ -201,7 +201,7 @@ X86_SIMD_SORT_UNROLL_LOOP(4) // using reg_t = typename vtype::reg_t; // using opmask_t = typename vtype::opmask_t; // reg_t arrzmm[16]; -// argzmm_t argzmm[16]; +// argreg_t argzmm[16]; // //X86_SIMD_SORT_UNROLL_LOOP(8) // for (int ii = 0; ii < 8; ++ii) { diff --git a/src/avx512-64bit-keyvalue-networks.hpp b/src/avx512-64bit-keyvalue-networks.hpp index bae8fa6d..62da3f77 100644 --- a/src/avx512-64bit-keyvalue-networks.hpp +++ b/src/avx512-64bit-keyvalue-networks.hpp @@ -128,45 +128,45 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(reg_t 
*key_zmm, index_type index_zmm2r = vtype2::permutexvar(rev_index2, index_zmm[2]); index_type index_zmm3r = vtype2::permutexvar(rev_index2, index_zmm[3]); - reg_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm3r); - reg_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm2r); + reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm3r); + reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm2r); reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); - typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]); - typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]); + typename vtype1::opmask_t movmask1 = vtype1::eq(key_reg_t1, key_zmm[0]); + typename vtype1::opmask_t movmask2 = vtype1::eq(key_reg_t2, key_zmm[1]); - index_type index_zmm_t1 + index_type index_reg_t1 = vtype2::mask_mov(index_zmm3r, movmask1, index_zmm[0]); index_type index_zmm_m1 = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm3r); - index_type index_zmm_t2 + index_type index_reg_t2 = vtype2::mask_mov(index_zmm2r, movmask2, index_zmm[1]); index_type index_zmm_m2 = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm2r); // 2) Recursive half clearer: 16 - reg_t key_zmm_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_zmm_t4 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_zmm_t3 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_zmm_t4 = vtype2::permutexvar(rev_index2, index_zmm_m1); + reg_t key_reg_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2); + reg_t key_reg_t4 = vtype1::permutexvar(rev_index1, key_zmm_m1); + index_type index_reg_t3 = vtype2::permutexvar(rev_index2, index_zmm_m2); + index_type index_reg_t4 = vtype2::permutexvar(rev_index2, index_zmm_m1); - reg_t key_zmm0 = vtype1::min(key_zmm_t1, key_zmm_t2); - reg_t key_zmm1 = vtype1::max(key_zmm_t1, key_zmm_t2); - reg_t key_zmm2 = vtype1::min(key_zmm_t3, key_zmm_t4); - reg_t key_zmm3 = vtype1::max(key_zmm_t3, key_zmm_t4); + reg_t key_zmm0 = vtype1::min(key_reg_t1, key_reg_t2); + reg_t key_zmm1 = vtype1::max(key_reg_t1, key_reg_t2); + reg_t key_zmm2 = vtype1::min(key_reg_t3, key_reg_t4); + reg_t key_zmm3 = vtype1::max(key_reg_t3, key_reg_t4); - movmask1 = vtype1::eq(key_zmm0, key_zmm_t1); - movmask2 = vtype1::eq(key_zmm2, key_zmm_t3); + movmask1 = vtype1::eq(key_zmm0, key_reg_t1); + movmask2 = vtype1::eq(key_zmm2, key_reg_t3); index_type index_zmm0 - = vtype2::mask_mov(index_zmm_t2, movmask1, index_zmm_t1); + = vtype2::mask_mov(index_reg_t2, movmask1, index_reg_t1); index_type index_zmm1 - = vtype2::mask_mov(index_zmm_t1, movmask1, index_zmm_t2); + = vtype2::mask_mov(index_reg_t1, movmask1, index_reg_t2); index_type index_zmm2 - = vtype2::mask_mov(index_zmm_t4, movmask2, index_zmm_t3); + = vtype2::mask_mov(index_reg_t4, movmask2, index_reg_t3); index_type index_zmm3 - = vtype2::mask_mov(index_zmm_t3, movmask2, index_zmm_t4); + = vtype2::mask_mov(index_reg_t3, movmask2, index_reg_t4); key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); @@ -197,80 +197,80 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(reg_t *key_zmm, index_type index_zmm6r = vtype2::permutexvar(rev_index2, index_zmm[6]); index_type index_zmm7r = vtype2::permutexvar(rev_index2, index_zmm[7]); - reg_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm7r); - reg_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm6r); - reg_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm5r); - reg_t key_zmm_t4 = vtype1::min(key_zmm[3], 
key_zmm4r); + reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm7r); + reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm6r); + reg_t key_reg_t3 = vtype1::min(key_zmm[2], key_zmm5r); + reg_t key_reg_t4 = vtype1::min(key_zmm[3], key_zmm4r); reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); - typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]); - typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]); - typename vtype1::opmask_t movmask3 = vtype1::eq(key_zmm_t3, key_zmm[2]); - typename vtype1::opmask_t movmask4 = vtype1::eq(key_zmm_t4, key_zmm[3]); + typename vtype1::opmask_t movmask1 = vtype1::eq(key_reg_t1, key_zmm[0]); + typename vtype1::opmask_t movmask2 = vtype1::eq(key_reg_t2, key_zmm[1]); + typename vtype1::opmask_t movmask3 = vtype1::eq(key_reg_t3, key_zmm[2]); + typename vtype1::opmask_t movmask4 = vtype1::eq(key_reg_t4, key_zmm[3]); - index_type index_zmm_t1 + index_type index_reg_t1 = vtype2::mask_mov(index_zmm7r, movmask1, index_zmm[0]); index_type index_zmm_m1 = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm7r); - index_type index_zmm_t2 + index_type index_reg_t2 = vtype2::mask_mov(index_zmm6r, movmask2, index_zmm[1]); index_type index_zmm_m2 = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm6r); - index_type index_zmm_t3 + index_type index_reg_t3 = vtype2::mask_mov(index_zmm5r, movmask3, index_zmm[2]); index_type index_zmm_m3 = vtype2::mask_mov(index_zmm[2], movmask3, index_zmm5r); - index_type index_zmm_t4 + index_type index_reg_t4 = vtype2::mask_mov(index_zmm4r, movmask4, index_zmm[3]); index_type index_zmm_m4 = vtype2::mask_mov(index_zmm[3], movmask4, index_zmm4r); - reg_t key_zmm_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4); - reg_t key_zmm_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3); - reg_t key_zmm_t7 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_zmm_t8 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_zmm_t5 = vtype2::permutexvar(rev_index2, index_zmm_m4); - index_type index_zmm_t6 = vtype2::permutexvar(rev_index2, index_zmm_m3); - index_type index_zmm_t7 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_zmm_t8 = vtype2::permutexvar(rev_index2, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + reg_t key_reg_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4); + reg_t key_reg_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3); + reg_t key_reg_t7 = vtype1::permutexvar(rev_index1, key_zmm_m2); + reg_t key_reg_t8 = vtype1::permutexvar(rev_index1, key_zmm_m1); + index_type index_reg_t5 = vtype2::permutexvar(rev_index2, index_zmm_m4); + index_type index_reg_t6 = vtype2::permutexvar(rev_index2, index_zmm_m3); + index_type index_reg_t7 = vtype2::permutexvar(rev_index2, index_zmm_m2); + index_type index_reg_t8 = vtype2::permutexvar(rev_index2, index_zmm_m1); + + COEX(key_reg_t1, key_reg_t3, index_reg_t1, index_reg_t3); + COEX(key_reg_t2, key_reg_t4, index_reg_t2, index_reg_t4); + 
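// Patch 2 is a pure rename of the zmm_t / key_zmm_t* / index_zmm_t* spellings to
// reg_t / key_reg_t* / index_reg_t*; the merge network itself is untouched. For
// reference, a scalar restatement of what a single COEX (compare-exchange) step
// computes for a key/value pair (illustrative only, not the library's masked
// vector implementation):
#include <utility>

template <typename K, typename V>
static inline void coex_scalar(K &key1, K &key2, V &val1, V &val2)
{
    // After this call key1 <= key2 and each value still travels with its key.
    if (key2 < key1) {
        std::swap(key1, key2);
        std::swap(val1, val2);
    }
}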
COEX(key_reg_t5, key_reg_t7, index_reg_t5, index_reg_t7); + COEX(key_reg_t6, key_reg_t8, index_reg_t6, index_reg_t8); + COEX(key_reg_t1, key_reg_t2, index_reg_t1, index_reg_t2); + COEX(key_reg_t3, key_reg_t4, index_reg_t3, index_reg_t4); + COEX(key_reg_t5, key_reg_t6, index_reg_t5, index_reg_t6); + COEX(key_reg_t7, key_reg_t8, index_reg_t7, index_reg_t8); key_zmm[0] - = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + = bitonic_merge_zmm_64bit(key_reg_t1, index_reg_t1); key_zmm[1] - = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + = bitonic_merge_zmm_64bit(key_reg_t2, index_reg_t2); key_zmm[2] - = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + = bitonic_merge_zmm_64bit(key_reg_t3, index_reg_t3); key_zmm[3] - = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + = bitonic_merge_zmm_64bit(key_reg_t4, index_reg_t4); key_zmm[4] - = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + = bitonic_merge_zmm_64bit(key_reg_t5, index_reg_t5); key_zmm[5] - = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + = bitonic_merge_zmm_64bit(key_reg_t6, index_reg_t6); key_zmm[6] - = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + = bitonic_merge_zmm_64bit(key_reg_t7, index_reg_t7); key_zmm[7] - = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); - - index_zmm[0] = index_zmm_t1; - index_zmm[1] = index_zmm_t2; - index_zmm[2] = index_zmm_t3; - index_zmm[3] = index_zmm_t4; - index_zmm[4] = index_zmm_t5; - index_zmm[5] = index_zmm_t6; - index_zmm[6] = index_zmm_t7; - index_zmm[7] = index_zmm_t8; + = bitonic_merge_zmm_64bit(key_reg_t8, index_reg_t8); + + index_zmm[0] = index_reg_t1; + index_zmm[1] = index_reg_t2; + index_zmm[2] = index_reg_t3; + index_zmm[3] = index_reg_t4; + index_zmm[4] = index_reg_t5; + index_zmm[5] = index_reg_t6; + index_zmm[6] = index_reg_t7; + index_zmm[7] = index_reg_t8; } template (key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); - COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); - COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); - COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); + index_zmm[7], vtype1::eq(key_reg_t8, key_zmm[7]), index_zmm8r); + + reg_t key_reg_t9 = vtype1::permutexvar(rev_index1, key_zmm_m8); + reg_t key_reg_t10 = vtype1::permutexvar(rev_index1, key_zmm_m7); + reg_t key_reg_t11 = vtype1::permutexvar(rev_index1, key_zmm_m6); + reg_t key_reg_t12 = vtype1::permutexvar(rev_index1, key_zmm_m5); + reg_t key_reg_t13 = vtype1::permutexvar(rev_index1, key_zmm_m4); + reg_t key_reg_t14 = vtype1::permutexvar(rev_index1, key_zmm_m3); + reg_t key_reg_t15 = vtype1::permutexvar(rev_index1, key_zmm_m2); + reg_t key_reg_t16 = vtype1::permutexvar(rev_index1, key_zmm_m1); + index_type index_reg_t9 = vtype2::permutexvar(rev_index2, index_zmm_m8); + index_type index_reg_t10 = vtype2::permutexvar(rev_index2, index_zmm_m7); + index_type index_reg_t11 = vtype2::permutexvar(rev_index2, index_zmm_m6); + index_type index_reg_t12 = vtype2::permutexvar(rev_index2, index_zmm_m5); + index_type index_reg_t13 = vtype2::permutexvar(rev_index2, index_zmm_m4); + index_type index_reg_t14 = vtype2::permutexvar(rev_index2, index_zmm_m3); + index_type index_reg_t15 = vtype2::permutexvar(rev_index2, index_zmm_m2); + index_type index_reg_t16 = vtype2::permutexvar(rev_index2, index_zmm_m1); + + COEX(key_reg_t1, key_reg_t5, index_reg_t1, index_reg_t5); + COEX(key_reg_t2, key_reg_t6, index_reg_t2, index_reg_t6); + COEX(key_reg_t3, key_reg_t7, index_reg_t3, index_reg_t7); + COEX(key_reg_t4, 
key_reg_t8, index_reg_t4, index_reg_t8); + COEX(key_reg_t9, key_reg_t13, index_reg_t9, index_reg_t13); COEX( - key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); + key_reg_t10, key_reg_t14, index_reg_t10, index_reg_t14); COEX( - key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); + key_reg_t11, key_reg_t15, index_reg_t11, index_reg_t15); COEX( - key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); + key_reg_t12, key_reg_t16, index_reg_t12, index_reg_t16); - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); + COEX(key_reg_t1, key_reg_t3, index_reg_t1, index_reg_t3); + COEX(key_reg_t2, key_reg_t4, index_reg_t2, index_reg_t4); + COEX(key_reg_t5, key_reg_t7, index_reg_t5, index_reg_t7); + COEX(key_reg_t6, key_reg_t8, index_reg_t6, index_reg_t8); + COEX(key_reg_t9, key_reg_t11, index_reg_t9, index_reg_t11); COEX( - key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); + key_reg_t10, key_reg_t12, index_reg_t10, index_reg_t12); COEX( - key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); + key_reg_t13, key_reg_t15, index_reg_t13, index_reg_t15); COEX( - key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); + key_reg_t14, key_reg_t16, index_reg_t14, index_reg_t16); - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); + COEX(key_reg_t1, key_reg_t2, index_reg_t1, index_reg_t2); + COEX(key_reg_t3, key_reg_t4, index_reg_t3, index_reg_t4); + COEX(key_reg_t5, key_reg_t6, index_reg_t5, index_reg_t6); + COEX(key_reg_t7, key_reg_t8, index_reg_t7, index_reg_t8); + COEX(key_reg_t9, key_reg_t10, index_reg_t9, index_reg_t10); COEX( - key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); + key_reg_t11, key_reg_t12, index_reg_t11, index_reg_t12); COEX( - key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); + key_reg_t13, key_reg_t14, index_reg_t13, index_reg_t14); COEX( - key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); + key_reg_t15, key_reg_t16, index_reg_t15, index_reg_t16); // key_zmm[0] - = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + = bitonic_merge_zmm_64bit(key_reg_t1, index_reg_t1); key_zmm[1] - = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + = bitonic_merge_zmm_64bit(key_reg_t2, index_reg_t2); key_zmm[2] - = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + = bitonic_merge_zmm_64bit(key_reg_t3, index_reg_t3); key_zmm[3] - = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + = bitonic_merge_zmm_64bit(key_reg_t4, index_reg_t4); key_zmm[4] - = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + = bitonic_merge_zmm_64bit(key_reg_t5, index_reg_t5); key_zmm[5] - = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + = bitonic_merge_zmm_64bit(key_reg_t6, index_reg_t6); key_zmm[6] - = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + = bitonic_merge_zmm_64bit(key_reg_t7, index_reg_t7); key_zmm[7] - = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + = bitonic_merge_zmm_64bit(key_reg_t8, index_reg_t8); key_zmm[8] - = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, - index_zmm_t10); - key_zmm[10] = 
bitonic_merge_zmm_64bit(key_zmm_t11, - index_zmm_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, - index_zmm_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, - index_zmm_t13); - key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, - index_zmm_t14); - key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, - index_zmm_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, - index_zmm_t16); - - index_zmm[0] = index_zmm_t1; - index_zmm[1] = index_zmm_t2; - index_zmm[2] = index_zmm_t3; - index_zmm[3] = index_zmm_t4; - index_zmm[4] = index_zmm_t5; - index_zmm[5] = index_zmm_t6; - index_zmm[6] = index_zmm_t7; - index_zmm[7] = index_zmm_t8; - index_zmm[8] = index_zmm_t9; - index_zmm[9] = index_zmm_t10; - index_zmm[10] = index_zmm_t11; - index_zmm[11] = index_zmm_t12; - index_zmm[12] = index_zmm_t13; - index_zmm[13] = index_zmm_t14; - index_zmm[14] = index_zmm_t15; - index_zmm[15] = index_zmm_t16; + = bitonic_merge_zmm_64bit(key_reg_t9, index_reg_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_reg_t10, + index_reg_t10); + key_zmm[10] = bitonic_merge_zmm_64bit(key_reg_t11, + index_reg_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_reg_t12, + index_reg_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_reg_t13, + index_reg_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_reg_t14, + index_reg_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_reg_t15, + index_reg_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_reg_t16, + index_reg_t16); + + index_zmm[0] = index_reg_t1; + index_zmm[1] = index_reg_t2; + index_zmm[2] = index_reg_t3; + index_zmm[3] = index_reg_t4; + index_zmm[4] = index_reg_t5; + index_zmm[5] = index_reg_t6; + index_zmm[6] = index_reg_t7; + index_zmm[7] = index_reg_t8; + index_zmm[8] = index_reg_t9; + index_zmm[9] = index_reg_t10; + index_zmm[10] = index_reg_t11; + index_zmm[11] = index_reg_t12; + index_zmm[12] = index_reg_t13; + index_zmm[13] = index_reg_t14; + index_zmm[14] = index_reg_t15; + index_zmm[15] = index_reg_t16; } diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h index 5c73aede..375afc0b 100644 --- a/src/avx512-common-argsort.h +++ b/src/avx512-common-argsort.h @@ -13,7 +13,7 @@ #include using argtype = zmm_vector; -using argzmm_t = typename argtype::reg_t; +using argreg_t = typename argtype::reg_t; /* * Parition one ZMM register based on the pivot and returns the index of the @@ -23,7 +23,7 @@ template static inline int32_t partition_vec(type_t *arg, int64_t left, int64_t right, - const argzmm_t arg_vec, + const argreg_t arg_vec, const reg_t curr_vec, const reg_t pivot_vec, reg_t *smallest_vec, @@ -74,7 +74,7 @@ static inline int64_t partition_avx512(type_t *arr, reg_t max_vec = vtype::set1(*biggest); if (right - left == vtype::numlanes) { - argzmm_t argvec = argtype::loadu(arg + left); + argreg_t argvec = argtype::loadu(arg + left); reg_t vec = vtype::i64gather(arr, arg + left); int32_t amount_gt_pivot = partition_vec(arg, left, @@ -90,9 +90,9 @@ static inline int64_t partition_avx512(type_t *arr, } // first and last vtype::numlanes values are partitioned at the end - argzmm_t argvec_left = argtype::loadu(arg + left); + argreg_t argvec_left = argtype::loadu(arg + left); reg_t vec_left = vtype::i64gather(arr, arg + left); - argzmm_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); + argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); // store points of the vectors int64_t r_store = right - vtype::numlanes; @@ -101,7 +101,7 
@@ static inline int64_t partition_avx512(type_t *arr, left += vtype::numlanes; right -= vtype::numlanes; while (right - left != 0) { - argzmm_t arg_vec; + argreg_t arg_vec; reg_t curr_vec; /* * if fewer elements are stored on the right side of the array, @@ -195,7 +195,7 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, // first and last vtype::numlanes values are partitioned at the end reg_t vec_left[num_unroll], vec_right[num_unroll]; - argzmm_t argvec_left[num_unroll], argvec_right[num_unroll]; + argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); @@ -212,7 +212,7 @@ X86_SIMD_SORT_UNROLL_LOOP(8) left += num_unroll * vtype::numlanes; right -= num_unroll * vtype::numlanes; while (right - left != 0) { - argzmm_t arg_vec[num_unroll]; + argreg_t arg_vec[num_unroll]; reg_t curr_vec[num_unroll]; /* * if fewer elements are stored on the right side of the array, diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 4ebd9475..45430e0e 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -484,16 +484,16 @@ X86_SIMD_SORT_UNROLL_LOOP(8) template -static void COEX(zmm_t1 &key1, zmm_t1 &key2, zmm_t2 &index1, zmm_t2 &index2) + typename reg_t1 = typename vtype1::reg_t, + typename reg_t2 = typename vtype2::reg_t> +static void COEX(reg_t1 &key1, reg_t1 &key2, reg_t2 &index1, reg_t2 &index2) { - zmm_t1 key_t1 = vtype1::min(key1, key2); - zmm_t1 key_t2 = vtype1::max(key1, key2); + reg_t1 key_t1 = vtype1::min(key1, key2); + reg_t1 key_t2 = vtype1::max(key1, key2); - zmm_t2 index_t1 + reg_t2 index_t1 = vtype2::mask_mov(index2, vtype1::eq(key_t1, key1), index1); - zmm_t2 index_t2 + reg_t2 index_t2 = vtype2::mask_mov(index1, vtype1::eq(key_t1, key1), index2); key1 = key_t1; @@ -503,16 +503,16 @@ static void COEX(zmm_t1 &key1, zmm_t1 &key2, zmm_t2 &index1, zmm_t2 &index2) } template -static inline zmm_t1 cmp_merge(zmm_t1 in1, - zmm_t1 in2, - zmm_t2 &indexes1, - zmm_t2 indexes2, +static inline reg_t1 cmp_merge(reg_t1 in1, + reg_t1 in2, + reg_t2 &indexes1, + reg_t2 indexes2, opmask_t mask) { - zmm_t1 tmp_keys = cmp_merge(in1, in2, mask); + reg_t1 tmp_keys = cmp_merge(in1, in2, mask); indexes1 = vtype2::mask_mov(indexes2, vtype1::eq(tmp_keys, in1), indexes1); return tmp_keys; // 0 -> min, 1 -> max } @@ -525,17 +525,17 @@ template + typename reg_t1 = typename vtype1::reg_t, + typename reg_t2 = typename vtype2::reg_t> static inline int32_t partition_vec(type_t1 *keys, type_t2 *indexes, int64_t left, int64_t right, - const zmm_t1 keys_vec, - const zmm_t2 indexes_vec, - const zmm_t1 pivot_vec, - zmm_t1 *smallest_vec, - zmm_t1 *biggest_vec) + const reg_t1 keys_vec, + const reg_t2 indexes_vec, + const reg_t1 pivot_vec, + reg_t1 *smallest_vec, + reg_t1 *biggest_vec) { /* which elements are larger than the pivot */ typename vtype1::opmask_t gt_mask = vtype1::ge(keys_vec, pivot_vec); @@ -560,8 +560,8 @@ template + typename reg_t1 = typename vtype1::reg_t, + typename reg_t2 = typename vtype2::reg_t> static inline int64_t partition_avx512(type_t1 *keys, type_t2 *indexes, int64_t left, @@ -587,15 +587,15 @@ static inline int64_t partition_avx512(type_t1 *keys, if (left == right) return left; /* less than vtype1::numlanes elements in the array */ - zmm_t1 pivot_vec = vtype1::set1(pivot); - zmm_t1 min_vec = vtype1::set1(*smallest); - zmm_t1 max_vec = vtype1::set1(*biggest); + reg_t1 pivot_vec = vtype1::set1(pivot); + reg_t1 
min_vec = vtype1::set1(*smallest); + reg_t1 max_vec = vtype1::set1(*biggest); if (right - left == vtype1::numlanes) { - zmm_t1 keys_vec = vtype1::loadu(keys + left); + reg_t1 keys_vec = vtype1::loadu(keys + left); int32_t amount_gt_pivot; - zmm_t2 indexes_vec = vtype2::loadu(indexes + left); + reg_t2 indexes_vec = vtype2::loadu(indexes + left); amount_gt_pivot = partition_vec(keys, indexes, left, @@ -612,10 +612,10 @@ static inline int64_t partition_avx512(type_t1 *keys, } // first and last vtype1::numlanes values are partitioned at the end - zmm_t1 keys_vec_left = vtype1::loadu(keys + left); - zmm_t1 keys_vec_right = vtype1::loadu(keys + (right - vtype1::numlanes)); - zmm_t2 indexes_vec_left; - zmm_t2 indexes_vec_right; + reg_t1 keys_vec_left = vtype1::loadu(keys + left); + reg_t1 keys_vec_right = vtype1::loadu(keys + (right - vtype1::numlanes)); + reg_t2 indexes_vec_left; + reg_t2 indexes_vec_right; indexes_vec_left = vtype2::loadu(indexes + left); indexes_vec_right = vtype2::loadu(indexes + (right - vtype1::numlanes)); @@ -626,8 +626,8 @@ static inline int64_t partition_avx512(type_t1 *keys, left += vtype1::numlanes; right -= vtype1::numlanes; while (right - left != 0) { - zmm_t1 keys_vec; - zmm_t2 indexes_vec; + reg_t1 keys_vec; + reg_t2 indexes_vec; /* * if fewer elements are stored on the right side of the array, * then next elements are loaded from the right side, @@ -757,7 +757,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, { // median of 16 int64_t size = (right - left) / 16; - using zmm_t = typename vtype::reg_t; + using reg_t = typename vtype::reg_t; type_t vec_arr[16] = {arr[left + size], arr[left + 2 * size], arr[left + 3 * size], @@ -774,8 +774,8 @@ X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, arr[left + 14 * size], arr[left + 15 * size], arr[left + 16 * size]}; - zmm_t rand_vec = vtype::loadu(vec_arr); - zmm_t sort = vtype::sort_vec(rand_vec); + reg_t rand_vec = vtype::loadu(vec_arr); + reg_t sort = vtype::sort_vec(rand_vec); // pivot will never be a nan, since there are no nan's! return ((type_t *)&sort)[8]; } @@ -787,8 +787,8 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, { // median of 8 int64_t size = (right - left) / 8; - using zmm_t = typename vtype::reg_t; - zmm_t rand_vec = vtype::set(arr[left + size], + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[left + size], arr[left + 2 * size], arr[left + 3 * size], arr[left + 4 * size], @@ -797,7 +797,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, arr[left + 7 * size], arr[left + 8 * size]); // pivot will never be a nan, since there are no nan's! 
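// As with the rest of patch 2, this hunk only renames the local zmm_t alias to
// reg_t; the pivot logic is unchanged: sample 8 evenly spaced keys, sort them
// inside one register, and read lane 4 as the median. A scalar restatement of
// that computation using std::sort instead of the library's in-register sorting
// network (assumes double keys; the helper name is made up for illustration):
#include <algorithm>
#include <cstdint>

static inline double median_of_8_scalar(const double *arr,
                                        int64_t left,
                                        int64_t right)
{
    int64_t size = (right - left) / 8;
    double s[8];
    for (int i = 0; i < 8; ++i)
        s[i] = arr[left + (i + 1) * size];
    std::sort(s, s + 8);
    return s[4]; // the vector version reads the same lane from its sorted register
}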
- zmm_t sort = vtype::sort_vec(rand_vec); + reg_t sort = vtype::sort_vec(rand_vec); return ((type_t *)&sort)[4]; } From 22c2f02c0ccf4f56432ecbe3383113bfa2deee67 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 7 Sep 2023 13:37:01 -0700 Subject: [PATCH 3/3] Replace zmmi_t with regi_t --- src/avx512-64bit-common.h | 32 +++++++++++++------------- src/avx512-64bit-keyvalue-networks.hpp | 20 ++++++++-------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index bbbd440b..3227e071 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -29,7 +29,7 @@ template <> struct ymm_vector { using type_t = float; using reg_t = __m256; - using zmmi_t = __m256i; + using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -45,7 +45,7 @@ struct ymm_vector { { return _mm256_set1_ps(type_max()); } - static zmmi_t + static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); @@ -189,7 +189,7 @@ template <> struct ymm_vector { using type_t = uint32_t; using reg_t = __m256i; - using zmmi_t = __m256i; + using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -206,7 +206,7 @@ struct ymm_vector { return _mm256_set1_epi32(type_max()); } - static zmmi_t + static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); @@ -335,7 +335,7 @@ template <> struct ymm_vector { using type_t = int32_t; using reg_t = __m256i; - using zmmi_t = __m256i; + using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -352,7 +352,7 @@ struct ymm_vector { return _mm256_set1_epi32(type_max()); } // TODO: this should broadcast bits as is? - static zmmi_t + static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); @@ -481,7 +481,7 @@ template <> struct zmm_vector { using type_t = int64_t; using reg_t = __m512i; - using zmmi_t = __m512i; + using regi_t = __m512i; using halfreg_t = __m512i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -501,7 +501,7 @@ struct zmm_vector { return _mm512_set1_epi64(type_max()); } // TODO: this should broadcast bits as is? 
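// Patch 3 renames the integer-index register alias zmmi_t to regi_t; the
// reverse() helpers built on it are unchanged: seti(NETWORK_64BIT_2) produces a
// reversed lane-index register that is fed to permutexvar. A standalone sketch
// of that trick for 8 x int64_t, assuming NETWORK_64BIT_2 expands to
// 0, 1, 2, ..., 7 in set_epi64's high-to-low argument order (the helper name
// below is illustrative, not library code):
#include <immintrin.h>

static inline __m512i reverse_lanes_epi64(__m512i v)
{
    // _mm512_set_epi64 lists arguments from lane 7 down to lane 0, so this
    // index register maps output lane i to input lane 7 - i.
    const __m512i rev_index = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
    return _mm512_permutexvar_epi64(rev_index, v);
}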
- static zmmi_t + static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); @@ -615,7 +615,7 @@ struct zmm_vector { } static reg_t reverse(reg_t zmm) { - const zmmi_t rev_index = seti(NETWORK_64BIT_2); + const regi_t rev_index = seti(NETWORK_64BIT_2); return permutexvar(rev_index, zmm); } static reg_t bitonic_merge(reg_t x) @@ -631,7 +631,7 @@ template <> struct zmm_vector { using type_t = uint64_t; using reg_t = __m512i; - using zmmi_t = __m512i; + using regi_t = __m512i; using halfreg_t = __m512i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -651,7 +651,7 @@ struct zmm_vector { return _mm512_set1_epi64(type_max()); } - static zmmi_t + static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); @@ -753,7 +753,7 @@ struct zmm_vector { } static reg_t reverse(reg_t zmm) { - const zmmi_t rev_index = seti(NETWORK_64BIT_2); + const regi_t rev_index = seti(NETWORK_64BIT_2); return permutexvar(rev_index, zmm); } static reg_t bitonic_merge(reg_t x) @@ -769,7 +769,7 @@ template <> struct zmm_vector { using type_t = double; using reg_t = __m512d; - using zmmi_t = __m512i; + using regi_t = __m512i; using halfreg_t = __m512d; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -788,7 +788,7 @@ struct zmm_vector { { return _mm512_set1_pd(type_max()); } - static zmmi_t + static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); @@ -901,7 +901,7 @@ struct zmm_vector { } static reg_t reverse(reg_t zmm) { - const zmmi_t rev_index = seti(NETWORK_64BIT_2); + const regi_t rev_index = seti(NETWORK_64BIT_2); return permutexvar(rev_index, zmm); } static reg_t bitonic_merge(reg_t x) @@ -921,7 +921,7 @@ struct zmm_vector { template X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm) { - const typename vtype::zmmi_t rev_index = vtype::seti(NETWORK_64BIT_2); + const typename vtype::regi_t rev_index = vtype::seti(NETWORK_64BIT_2); zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAA); zmm = cmp_merge( diff --git a/src/avx512-64bit-keyvalue-networks.hpp b/src/avx512-64bit-keyvalue-networks.hpp index 62da3f77..e9577b79 100644 --- a/src/avx512-64bit-keyvalue-networks.hpp +++ b/src/avx512-64bit-keyvalue-networks.hpp @@ -5,8 +5,8 @@ template X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t key_zmm, index_type &index_zmm) { - const typename vtype1::zmmi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::zmmi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); + const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); + const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); key_zmm = cmp_merge( key_zmm, vtype1::template shuffle(key_zmm), @@ -87,8 +87,8 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(reg_t &key_zmm1, index_type &index_zmm1, index_type &index_zmm2) { - const typename vtype1::zmmi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::zmmi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); + const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); + const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); // 1) First step of a merging network: coex of zmm1 and zmm2 reversed key_zmm2 = vtype1::permutexvar(rev_index1, key_zmm2); index_zmm2 = vtype2::permutexvar(rev_index2, index_zmm2); @@ -120,8 +120,8 @@ template