diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index 6ffddbde..fd7f92af 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -11,15 +11,6 @@ #include "xss-common-qsort.h" #include "avx2-emu-funcs.hpp" -/* - * Constants used in sorting 8 elements in a ymm registers. Based on Bitonic - * sorting network (see - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) - */ -// ymm 3, 2, 1, 0 -#define NETWORK_64BIT_R 0, 1, 2, 3 -#define NETWORK_64BIT_1 1, 0, 3, 2 - /* * Assumes ymm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index c4084c68..799303d8 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -65,24 +65,15 @@ std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) }); } -/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of - * undefined template 'zmm_vector'*/ -#ifdef __APPLE__ -using argtype = typename std::conditional, - zmm_vector>::type; -#else -using argtype = typename std::conditional, - zmm_vector>::type; -#endif -using argreg_t = typename argtype::reg_t; - /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. */ -template +template X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, arrsize_t left, arrsize_t right, @@ -107,7 +98,11 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, * Parition an array based on the pivot and returns the index of the * last element that is less than equal to the pivot. */ -template +template X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, arrsize_t *arg, arrsize_t left, @@ -131,7 +126,6 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, if (left == right) return left; /* less than vtype::numlanes elements in the array */ - using reg_t = typename vtype::reg_t; reg_t pivot_vec = vtype::set1(pivot); reg_t min_vec = vtype::set1(*smallest); reg_t max_vec = vtype::set1(*biggest); @@ -139,14 +133,15 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, if (right - left == vtype::numlanes) { argreg_t argvec = argtype::loadu(arg + left); reg_t vec = vtype::i64gather(arr, arg + left); - int32_t amount_gt_pivot = partition_vec(arg, - left, - left + vtype::numlanes, - argvec, - vec, - pivot_vec, - &min_vec, - &max_vec); + int32_t amount_gt_pivot + = partition_vec(arg, + left, + left + vtype::numlanes, + argvec, + vec, + pivot_vec, + &min_vec, + &max_vec); *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); return left + (vtype::numlanes - amount_gt_pivot); @@ -183,37 +178,38 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, } // partition the current vector and save it on both sides of the array int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec, - curr_vec, - pivot_vec, - &min_vec, - &max_vec); + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec, + curr_vec, + pivot_vec, + &min_vec, + &max_vec); ; r_store -= amount_gt_pivot; l_store += (vtype::numlanes - amount_gt_pivot); } /* partition and save vec_left and vec_right */ - int32_t amount_gt_pivot = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left, - vec_left, - pivot_vec, - &min_vec, - &max_vec); + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + 
vtype::numlanes, + argvec_left, + vec_left, + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec(arg, - l_store, - l_store + vtype::numlanes, - argvec_right, - vec_right, - pivot_vec, - &min_vec, - &max_vec); + amount_gt_pivot = partition_vec(arg, + l_store, + l_store + vtype::numlanes, + argvec_right, + vec_right, + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); @@ -221,8 +217,10 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, } template + typename type_t = typename vtype::type_t, + typename argreg_t = typename argtype::reg_t> X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, arrsize_t *arg, arrsize_t left, @@ -232,7 +230,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, type_t *biggest) { if (right - left <= 8 * num_unroll * vtype::numlanes) { - return partition_avx512( + return partition_avx512( arr, arg, left, right, pivot, smallest, biggest); } /* make array length divisible by vtype::numlanes , shortening the array */ @@ -305,14 +303,14 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec[ii], - curr_vec[ii], - pivot_vec, - &min_vec, - &max_vec); + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec[ii], + curr_vec[ii], + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); r_store -= amount_gt_pivot; } @@ -322,28 +320,28 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left[ii], - vec_left[ii], - pivot_vec, - &min_vec, - &max_vec); + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left[ii], + vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); r_store -= amount_gt_pivot; } X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_right[ii], - vec_right[ii], - pivot_vec, - &min_vec, - &max_vec); + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_right[ii], + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); r_store -= amount_gt_pivot; } @@ -352,195 +350,6 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, return l_store; } -template -X86_SIMD_SORT_INLINE void -argsort_8_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - using reg_t = typename vtype::reg_t; - typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; - argreg_t argzmm = argtype::maskz_loadu(load_mask, arg); - reg_t arrzmm = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask, argzmm, arr); - arrzmm = sort_zmm_64bit(arrzmm, argzmm); - argtype::mask_storeu(arg, load_mask, argzmm); -} - -template -X86_SIMD_SORT_INLINE void -argsort_16_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - if (N <= 8) { - argsort_8_64bit(arr, arg, N); - return; - } - using reg_t = typename vtype::reg_t; - typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - argreg_t argzmm1 = argtype::loadu(arg); 
- argreg_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8); - reg_t arrzmm1 = vtype::i64gather(arr, arg); - reg_t arrzmm2 = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask, argzmm2, arr); - arrzmm1 = sort_zmm_64bit(arrzmm1, argzmm1); - arrzmm2 = sort_zmm_64bit(arrzmm2, argzmm2); - bitonic_merge_two_zmm_64bit( - arrzmm1, arrzmm2, argzmm1, argzmm2); - argtype::storeu(arg, argzmm1); - argtype::mask_storeu(arg + 8, load_mask, argzmm2); -} - -template -X86_SIMD_SORT_INLINE void -argsort_32_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - if (N <= 16) { - argsort_16_64bit(arr, arg, N); - return; - } - using reg_t = typename vtype::reg_t; - using opmask_t = typename vtype::opmask_t; - reg_t arrzmm[4]; - argreg_t argzmm[4]; - - X86_SIMD_SORT_UNROLL_LOOP(2) - for (int ii = 0; ii < 2; ++ii) { - argzmm[ii] = argtype::loadu(arg + 8 * ii); - arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii); - arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); - } - - uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; - opmask_t load_mask[2] = {0xFF, 0xFF}; - X86_SIMD_SORT_UNROLL_LOOP(2) - for (int ii = 0; ii < 2; ++ii) { - load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF; - argzmm[ii + 2] = argtype::maskz_loadu(load_mask[ii], arg + 16 + 8 * ii); - arrzmm[ii + 2] = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask[ii], argzmm[ii + 2], arr); - arrzmm[ii + 2] = sort_zmm_64bit(arrzmm[ii + 2], - argzmm[ii + 2]); - } - - bitonic_merge_two_zmm_64bit( - arrzmm[0], arrzmm[1], argzmm[0], argzmm[1]); - bitonic_merge_two_zmm_64bit( - arrzmm[2], arrzmm[3], argzmm[2], argzmm[3]); - bitonic_merge_four_zmm_64bit(arrzmm, argzmm); - - argtype::storeu(arg, argzmm[0]); - argtype::storeu(arg + 8, argzmm[1]); - argtype::mask_storeu(arg + 16, load_mask[0], argzmm[2]); - argtype::mask_storeu(arg + 24, load_mask[1], argzmm[3]); -} - -template -X86_SIMD_SORT_INLINE void -argsort_64_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - if (N <= 32) { - argsort_32_64bit(arr, arg, N); - return; - } - using reg_t = typename vtype::reg_t; - using opmask_t = typename vtype::opmask_t; - reg_t arrzmm[8]; - argreg_t argzmm[8]; - - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - argzmm[ii] = argtype::loadu(arg + 8 * ii); - arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii); - arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); - } - - opmask_t load_mask[4] = {0xFF, 0xFF, 0xFF, 0xFF}; - uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF; - argzmm[ii + 4] = argtype::maskz_loadu(load_mask[ii], arg + 32 + 8 * ii); - arrzmm[ii + 4] = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask[ii], argzmm[ii + 4], arr); - arrzmm[ii + 4] = sort_zmm_64bit(arrzmm[ii + 4], - argzmm[ii + 4]); - } - - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 8; ii = ii + 2) { - bitonic_merge_two_zmm_64bit( - arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); - } - bitonic_merge_four_zmm_64bit(arrzmm, argzmm); - bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); - bitonic_merge_eight_zmm_64bit(arrzmm, argzmm); - - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - argtype::storeu(arg + 8 * ii, argzmm[ii]); - } - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - argtype::mask_storeu(arg + 32 + 8 * ii, load_mask[ii], argzmm[ii + 4]); - } -} - -/* arsort 128 doesn't seem to make much of a difference to perf*/ -//template -//X86_SIMD_SORT_INLINE void 
-//argsort_128_64bit(type_t *arr, arrsize_t *arg, int32_t N) -//{ -// if (N <= 64) { -// argsort_64_64bit(arr, arg, N); -// return; -// } -// using reg_t = typename vtype::reg_t; -// using opmask_t = typename vtype::opmask_t; -// reg_t arrzmm[16]; -// argreg_t argzmm[16]; -// -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argzmm[ii] = argtype::loadu(arg + 8*ii); -// arrzmm[ii] = vtype::i64gather(argzmm[ii], arr); -// arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); -// } -// -// opmask_t load_mask[8] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; -// if (N != 128) { -// uarrsize_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// load_mask[ii] = (combined_mask >> (ii*8)) & 0xFF; -// } -// } -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argzmm[ii+8] = argtype::maskz_loadu(load_mask[ii], arg + 64 + 8*ii); -// arrzmm[ii+8] = vtype::template mask_i64gather(vtype::zmm_max(), load_mask[ii], argzmm[ii+8], arr); -// arrzmm[ii+8] = sort_zmm_64bit(arrzmm[ii+8], argzmm[ii+8]); -// } -// -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 16; ii = ii + 2) { -// bitonic_merge_two_zmm_64bit(arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); -// } -// bitonic_merge_four_zmm_64bit(arrzmm, argzmm); -// bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); -// bitonic_merge_four_zmm_64bit(arrzmm + 8, argzmm + 8); -// bitonic_merge_four_zmm_64bit(arrzmm + 12, argzmm + 12); -// bitonic_merge_eight_zmm_64bit(arrzmm, argzmm); -// bitonic_merge_eight_zmm_64bit(arrzmm+8, argzmm+8); -// bitonic_merge_sixteen_zmm_64bit(arrzmm, argzmm); -// -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argtype::storeu(arg + 8*ii, argzmm[ii]); -// } -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argtype::mask_storeu(arg + 64 + 8*ii, load_mask[ii], argzmm[ii + 8]); -// } -//} - template X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, arrsize_t *arg, @@ -568,7 +377,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, } } -template +template X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, arrsize_t *arg, arrsize_t left, @@ -585,22 +394,25 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, /* * Base case: use bitonic networks to sort arrays <= 64 */ - if (right + 1 - left <= 64) { - argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); + if (right + 1 - left <= 256) { + argsort_n( + arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( + arrsize_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) - argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); + argsort_64bit_( + arr, arg, left, pivot_index - 1, max_iters - 1); if (pivot != biggest) - argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); + argsort_64bit_( + arr, arg, pivot_index, right, max_iters - 1); } -template +template X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, arrsize_t *arg, arrsize_t pos, @@ -618,20 +430,21 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, /* * Base case: use bitonic networks to sort arrays <= 64 */ - if (right + 1 - left <= 64) { - argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); + if (right + 1 - left <= 256) { + 
argsort_n( + arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( + arrsize_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if ((pivot != smallest) && (pos < pivot_index)) - argselect_64bit_( + argselect_64bit_( arr, arg, pos, left, pivot_index - 1, max_iters - 1); else if ((pivot != biggest) && (pos >= pivot_index)) - argselect_64bit_( + argselect_64bit_( arr, arg, pos, pivot_index, right, max_iters - 1); } @@ -640,9 +453,25 @@ template X86_SIMD_SORT_INLINE void avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) { + /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ using vectype = typename std::conditional, zmm_vector>::type; + +/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of + * undefined template 'zmm_vector'*/ +#ifdef __APPLE__ + using argtype = + typename std::conditional, + zmm_vector>::type; +#else + using argtype = + typename std::conditional, + zmm_vector>::type; +#endif + if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { @@ -651,7 +480,7 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) } } UNUSED(hasnan); - argsort_64bit_( + argsort_64bit_( arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } @@ -674,10 +503,25 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, arrsize_t arrsize, bool hasnan = false) { + /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ using vectype = typename std::conditional, zmm_vector>::type; +/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of + * undefined template 'zmm_vector'*/ +#ifdef __APPLE__ + using argtype = + typename std::conditional, + zmm_vector>::type; +#else + using argtype = + typename std::conditional, + zmm_vector>::type; +#endif + if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { @@ -686,7 +530,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, } } UNUSED(hasnan); - argselect_64bit_( + argselect_64bit_( arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 909f3b2b..911f5395 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -8,6 +8,7 @@ #define AVX512_64BIT_COMMON #include "xss-common-includes.h" +#include "avx2-32bit-qsort.hpp" /* * Constants used in sorting 8 elements in a ZMM registers. 
Based on Bitonic @@ -194,6 +195,19 @@ struct ymm_vector { { _mm256_storeu_ps((float *)mem, x); } + static reg_t cast_from(__m256i v) + { + return _mm256_castsi256_ps(v); + } + static __m256i cast_to(reg_t v) + { + return _mm256_castps_si256(v); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } }; template <> struct ymm_vector { @@ -354,6 +368,19 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } }; template <> struct ymm_vector { @@ -514,6 +541,19 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } }; template <> struct zmm_vector { diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 55f79bb1..1f446c68 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -182,356 +182,6 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t1 *keys, return l_store; } -template -X86_SIMD_SORT_INLINE void -sort_8_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - typename vtype1::opmask_t load_mask = (0x01 << N) - 0x01; - typename vtype1::reg_t key_zmm - = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys); - - typename vtype2::reg_t index_zmm - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes); - vtype1::mask_storeu(keys, - load_mask, - sort_zmm_64bit(key_zmm, index_zmm)); - vtype2::mask_storeu(indexes, load_mask, index_zmm); -} - -template -X86_SIMD_SORT_INLINE void -sort_16_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 8) { - sort_8_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using index_type = typename vtype2::reg_t; - - typename vtype1::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - - reg_t key_zmm1 = vtype1::loadu(keys); - reg_t key_zmm2 = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys + 8); - - index_type index_zmm1 = vtype2::loadu(indexes); - index_type index_zmm2 - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes + 8); - - key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); - key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); - bitonic_merge_two_zmm_64bit( - key_zmm1, key_zmm2, index_zmm1, index_zmm2); - - vtype2::storeu(indexes, index_zmm1); - vtype2::mask_storeu(indexes + 8, load_mask, index_zmm2); - - vtype1::storeu(keys, key_zmm1); - vtype1::mask_storeu(keys + 8, load_mask, key_zmm2); -} - -template -X86_SIMD_SORT_INLINE void -sort_32_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 16) { - sort_16_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using opmask_t = typename vtype2::opmask_t; - using index_type = typename vtype2::reg_t; - reg_t key_zmm[4]; - index_type index_zmm[4]; - - key_zmm[0] = vtype1::loadu(keys); - key_zmm[1] = vtype1::loadu(keys + 8); - - index_zmm[0] = vtype2::loadu(indexes); - index_zmm[1] = vtype2::loadu(indexes + 8); - - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], 
index_zmm[1]); - - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - key_zmm[2] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 16); - key_zmm[3] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 24); - - index_zmm[2] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 16); - index_zmm[3] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 24); - - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - - vtype2::storeu(indexes, index_zmm[0]); - vtype2::storeu(indexes + 8, index_zmm[1]); - vtype2::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); - vtype2::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); - - vtype1::storeu(keys, key_zmm[0]); - vtype1::storeu(keys + 8, key_zmm[1]); - vtype1::mask_storeu(keys + 16, load_mask1, key_zmm[2]); - vtype1::mask_storeu(keys + 24, load_mask2, key_zmm[3]); -} - -template -X86_SIMD_SORT_INLINE void -sort_64_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 32) { - sort_32_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using opmask_t = typename vtype1::opmask_t; - using index_type = typename vtype2::reg_t; - reg_t key_zmm[8]; - index_type index_zmm[8]; - - key_zmm[0] = vtype1::loadu(keys); - key_zmm[1] = vtype1::loadu(keys + 8); - key_zmm[2] = vtype1::loadu(keys + 16); - key_zmm[3] = vtype1::loadu(keys + 24); - - index_zmm[0] = vtype2::loadu(indexes); - index_zmm[1] = vtype2::loadu(indexes + 8); - index_zmm[2] = vtype2::loadu(indexes + 16); - index_zmm[3] = vtype2::loadu(indexes + 24); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - // N-32 >= 1 - uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - key_zmm[4] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 32); - key_zmm[5] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 40); - key_zmm[6] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 48); - key_zmm[7] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 56); - - index_zmm[4] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 32); - index_zmm[5] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 40); - index_zmm[6] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 48); - index_zmm[7] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 56); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], 
key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( - key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( - key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - - vtype2::storeu(indexes, index_zmm[0]); - vtype2::storeu(indexes + 8, index_zmm[1]); - vtype2::storeu(indexes + 16, index_zmm[2]); - vtype2::storeu(indexes + 24, index_zmm[3]); - vtype2::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); - vtype2::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); - vtype2::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); - vtype2::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); - - vtype1::storeu(keys, key_zmm[0]); - vtype1::storeu(keys + 8, key_zmm[1]); - vtype1::storeu(keys + 16, key_zmm[2]); - vtype1::storeu(keys + 24, key_zmm[3]); - vtype1::mask_storeu(keys + 32, load_mask1, key_zmm[4]); - vtype1::mask_storeu(keys + 40, load_mask2, key_zmm[5]); - vtype1::mask_storeu(keys + 48, load_mask3, key_zmm[6]); - vtype1::mask_storeu(keys + 56, load_mask4, key_zmm[7]); -} - -template -X86_SIMD_SORT_INLINE void -sort_128_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 64) { - sort_64_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using index_type = typename vtype2::reg_t; - using opmask_t = typename vtype1::opmask_t; - reg_t key_zmm[16]; - index_type index_zmm[16]; - - key_zmm[0] = vtype1::loadu(keys); - key_zmm[1] = vtype1::loadu(keys + 8); - key_zmm[2] = vtype1::loadu(keys + 16); - key_zmm[3] = vtype1::loadu(keys + 24); - key_zmm[4] = vtype1::loadu(keys + 32); - key_zmm[5] = vtype1::loadu(keys + 40); - key_zmm[6] = vtype1::loadu(keys + 48); - key_zmm[7] = vtype1::loadu(keys + 56); - - index_zmm[0] = vtype2::loadu(indexes); - index_zmm[1] = vtype2::loadu(indexes + 8); - index_zmm[2] = vtype2::loadu(indexes + 16); - index_zmm[3] = vtype2::loadu(indexes + 24); - index_zmm[4] = vtype2::loadu(indexes + 32); - index_zmm[5] = vtype2::loadu(indexes + 40); - index_zmm[6] = vtype2::loadu(indexes + 48); - index_zmm[7] = vtype2::loadu(indexes + 56); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; - opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; - if (N != 128) { - uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - load_mask5 = (combined_mask >> 32) & 0xFF; - load_mask6 = (combined_mask >> 40) & 0xFF; - load_mask7 = (combined_mask >> 48) & 0xFF; - load_mask8 = (combined_mask >> 56) & 0xFF; - } - key_zmm[8] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 64); - key_zmm[9] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 72); - key_zmm[10] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 
80); - key_zmm[11] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 88); - key_zmm[12] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask5, keys + 96); - key_zmm[13] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask6, keys + 104); - key_zmm[14] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask7, keys + 112); - key_zmm[15] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask8, keys + 120); - - index_zmm[8] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 64); - index_zmm[9] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 72); - index_zmm[10] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 80); - index_zmm[11] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 88); - index_zmm[12] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask5, indexes + 96); - index_zmm[13] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask6, indexes + 104); - index_zmm[14] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask7, indexes + 112); - index_zmm[15] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask8, indexes + 120); - key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); - key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); - key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); - key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); - key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); - key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); - key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); - key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); - - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( - key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( - key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_two_zmm_64bit( - key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); - bitonic_merge_two_zmm_64bit( - key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); - bitonic_merge_two_zmm_64bit( - key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); - bitonic_merge_two_zmm_64bit( - key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); - vtype2::storeu(indexes, index_zmm[0]); - vtype2::storeu(indexes + 8, index_zmm[1]); - vtype2::storeu(indexes + 16, index_zmm[2]); - vtype2::storeu(indexes + 24, index_zmm[3]); - vtype2::storeu(indexes + 32, index_zmm[4]); - vtype2::storeu(indexes + 40, index_zmm[5]); - vtype2::storeu(indexes + 48, index_zmm[6]); - vtype2::storeu(indexes + 56, index_zmm[7]); - vtype2::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); - vtype2::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); - vtype2::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); - vtype2::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); - vtype2::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); - vtype2::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); - vtype2::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); - vtype2::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); - - 
vtype1::storeu(keys, key_zmm[0]); - vtype1::storeu(keys + 8, key_zmm[1]); - vtype1::storeu(keys + 16, key_zmm[2]); - vtype1::storeu(keys + 24, key_zmm[3]); - vtype1::storeu(keys + 32, key_zmm[4]); - vtype1::storeu(keys + 40, key_zmm[5]); - vtype1::storeu(keys + 48, key_zmm[6]); - vtype1::storeu(keys + 56, key_zmm[7]); - vtype1::mask_storeu(keys + 64, load_mask1, key_zmm[8]); - vtype1::mask_storeu(keys + 72, load_mask2, key_zmm[9]); - vtype1::mask_storeu(keys + 80, load_mask3, key_zmm[10]); - vtype1::mask_storeu(keys + 88, load_mask4, key_zmm[11]); - vtype1::mask_storeu(keys + 96, load_mask5, key_zmm[12]); - vtype1::mask_storeu(keys + 104, load_mask6, key_zmm[13]); - vtype1::mask_storeu(keys + 112, load_mask7, key_zmm[14]); - vtype1::mask_storeu(keys + 120, load_mask8, key_zmm[15]); -} - template ( + kvsort_n( keys + left, indexes + left, (int32_t)(right + 1 - left)); return; } diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index cec1cb7a..fccf33dc 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -1,5 +1,8 @@ -#ifndef AVX512_KEYVALUE_NETWORKS -#define AVX512_KEYVALUE_NETWORKS +#ifndef XSS_KEYVALUE_NETWORKS +#define XSS_KEYVALUE_NETWORKS + +#include "avx512-64bit-qsort.hpp" +#include "avx2-64bit-qsort.hpp" template -X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(reg_t &key_zmm1, - reg_t &key_zmm2, - index_type &index_zmm1, - index_type &index_zmm2) + +template +X86_SIMD_SORT_INLINE void +bitonic_merge_dispatch(typename keyType::reg_t &key, + typename valueType::reg_t &value) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - key_zmm2 = vtype1::permutexvar(rev_index1, key_zmm2); - index_zmm2 = vtype2::permutexvar(rev_index2, index_zmm2); + constexpr int numlanes = keyType::numlanes; + if constexpr (numlanes == 8) { + key = bitonic_merge_zmm_64bit(key, value); + } + else { + static_assert(numlanes == -1, "should not reach here"); + UNUSED(key); + UNUSED(value); + } +} - reg_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); - reg_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); +template +X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, + typename valueType::reg_t &value) +{ + constexpr int numlanes = keyType::numlanes; + if constexpr (numlanes == 8) { + key = sort_zmm_64bit(key, value); + } + else { + static_assert(numlanes == -1, "should not reach here"); + UNUSED(key); + UNUSED(value); + } +} - typename vtype1::opmask_t movmask = vtype1::eq(key_zmm3, key_zmm1); +template +X86_SIMD_SORT_INLINE void bitonic_clean_n_vec(typename keyType::reg_t *keys, + typename valueType::reg_t *values) +{ + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int num = numVecs / 2; num >= 2; num /= 2) { + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int j = 0; j < numVecs; j += num) { + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < num / 2; i++) { + arrsize_t index1 = i + j; + arrsize_t index2 = i + j + num / 2; + COEX(keys[index1], + keys[index2], + values[index1], + values[index2]); + } + } + } +} - index_type index_zmm3 = vtype2::mask_mov(index_zmm2, movmask, index_zmm1); - index_type index_zmm4 = vtype2::mask_mov(index_zmm1, movmask, index_zmm2); +template +X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(typename keyType::reg_t *keys, + typename valueType::reg_t *values) +{ + // Do the reverse part + if constexpr (numVecs == 2) { + keys[1] = 
keyType::reverse(keys[1]); + values[1] = valueType::reverse(values[1]); + COEX(keys[0], keys[1], values[0], values[1]); + keys[1] = keyType::reverse(keys[1]); + values[1] = valueType::reverse(values[1]); + } + else if constexpr (numVecs > 2) { + // Reverse upper half + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + keys[numVecs - i - 1] = keyType::reverse(keys[numVecs - i - 1]); + values[numVecs - i - 1] + = valueType::reverse(values[numVecs - i - 1]); + + COEX(keys[i], + keys[numVecs - i - 1], + values[i], + values[numVecs - i - 1]); + + keys[numVecs - i - 1] = keyType::reverse(keys[numVecs - i - 1]); + values[numVecs - i - 1] + = valueType::reverse(values[numVecs - i - 1]); + } + } + + // Call cleaner + bitonic_clean_n_vec(keys, values); + + // Now do bitonic_merge + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + bitonic_merge_dispatch(keys[i], values[i]); + } +} - /* need to reverse the lower registers to keep the correct order */ - key_zmm4 = vtype1::permutexvar(rev_index1, key_zmm4); - index_zmm4 = vtype2::permutexvar(rev_index2, index_zmm4); +template +X86_SIMD_SORT_INLINE void +bitonic_fullmerge_n_vec(typename keyType::reg_t *keys, + typename valueType::reg_t *values) +{ + if constexpr (numPer > numVecs) { + UNUSED(keys); + UNUSED(values); + return; + } + else { + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / numPer; i++) { + bitonic_merge_n_vec( + keys + i * numPer, values + i * numPer); + } + bitonic_fullmerge_n_vec( + keys, values); + } +} - // 2) Recursive half cleaner for each - key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); - index_zmm1 = index_zmm3; - index_zmm2 = index_zmm4; +template +X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, + typename indexType::type_t *indices, + int N) +{ + using kreg_t = typename keyType::reg_t; + using ireg_t = typename indexType::reg_t; + + static_assert(numVecs > 0, "numVecs should be > 0"); + if constexpr (numVecs > 1) { + if (N * 2 <= numVecs * keyType::numlanes) { + argsort_n_vec(keys, indices, N); + return; + } + } + + kreg_t keyVecs[numVecs]; + ireg_t indexVecs[numVecs]; + + // Generate masks for loading and storing + typename keyType::opmask_t ioMasks[numVecs - numVecs / 2]; + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), + (uint64_t)keyType::numlanes); + ioMasks[j] = keyType::get_partial_loadmask(num_to_read); + } + + // Unmasked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + indexVecs[i] = indexType::loadu(indices + i * indexType::numlanes); + keyVecs[i] + = keyType::i64gather(keys, indices + i * indexType::numlanes); + } + // Masked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + indexVecs[i] = indexType::mask_loadu(indexType::zmm_max(), + ioMasks[j], + indices + i * indexType::numlanes); + + keyVecs[i] = keyType::template mask_i64gather( + keyType::zmm_max(), ioMasks[j], indexVecs[i], keys); + } + + // Sort each loaded vector + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + sort_vec_dispatch(keyVecs[i], indexVecs[i]); + } + + // Run the full merger + bitonic_fullmerge_n_vec(keyVecs, indexVecs); + + // Unmasked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + 
indexType::storeu(indices + i * indexType::numlanes, indexVecs[i]); + } + // Masked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + indexType::mask_storeu( + indices + i * indexType::numlanes, ioMasks[j], indexVecs[i]); + } } -// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive -// half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(reg_t *key_zmm, - index_type *index_zmm) + +template +X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, + typename valueType::type_t *values, + int N) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - // 1) First step of a merging network - reg_t key_zmm2r = vtype1::permutexvar(rev_index1, key_zmm[2]); - reg_t key_zmm3r = vtype1::permutexvar(rev_index1, key_zmm[3]); - index_type index_zmm2r = vtype2::permutexvar(rev_index2, index_zmm[2]); - index_type index_zmm3r = vtype2::permutexvar(rev_index2, index_zmm[3]); - - reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm3r); - reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm2r); - reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); - reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); - - typename vtype1::opmask_t movmask1 = vtype1::eq(key_reg_t1, key_zmm[0]); - typename vtype1::opmask_t movmask2 = vtype1::eq(key_reg_t2, key_zmm[1]); - - index_type index_reg_t1 - = vtype2::mask_mov(index_zmm3r, movmask1, index_zmm[0]); - index_type index_zmm_m1 - = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm3r); - index_type index_reg_t2 - = vtype2::mask_mov(index_zmm2r, movmask2, index_zmm[1]); - index_type index_zmm_m2 - = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm2r); - - // 2) Recursive half clearer: 16 - reg_t key_reg_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_reg_t4 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_reg_t3 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_reg_t4 = vtype2::permutexvar(rev_index2, index_zmm_m1); - - reg_t key_zmm0 = vtype1::min(key_reg_t1, key_reg_t2); - reg_t key_zmm1 = vtype1::max(key_reg_t1, key_reg_t2); - reg_t key_zmm2 = vtype1::min(key_reg_t3, key_reg_t4); - reg_t key_zmm3 = vtype1::max(key_reg_t3, key_reg_t4); - - movmask1 = vtype1::eq(key_zmm0, key_reg_t1); - movmask2 = vtype1::eq(key_zmm2, key_reg_t3); - - index_type index_zmm0 - = vtype2::mask_mov(index_reg_t2, movmask1, index_reg_t1); - index_type index_zmm1 - = vtype2::mask_mov(index_reg_t1, movmask1, index_reg_t2); - index_type index_zmm2 - = vtype2::mask_mov(index_reg_t4, movmask2, index_reg_t3); - index_type index_zmm3 - = vtype2::mask_mov(index_reg_t3, movmask2, index_reg_t4); - - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - - index_zmm[0] = index_zmm0; - index_zmm[1] = index_zmm1; - index_zmm[2] = index_zmm2; - index_zmm[3] = index_zmm3; + using kreg_t = typename keyType::reg_t; + using vreg_t = typename valueType::reg_t; + + static_assert(numVecs > 0, "numVecs should be > 0"); + if constexpr (numVecs > 1) { + if (N * 2 <= numVecs * keyType::numlanes) { + kvsort_n_vec(keys, values, N); + return; + } + } + + kreg_t keyVecs[numVecs]; + vreg_t valueVecs[numVecs]; + + // Generate masks for loading and storing + typename 
keyType::opmask_t ioMasks[numVecs - numVecs / 2]; + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), + (uint64_t)keyType::numlanes); + ioMasks[j] = keyType::get_partial_loadmask(num_to_read); + } + + // Unmasked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + keyVecs[i] = keyType::loadu(keys + i * keyType::numlanes); + valueVecs[i] = valueType::loadu(values + i * valueType::numlanes); + } + // Masked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + keyVecs[i] = keyType::mask_loadu( + keyType::zmm_max(), ioMasks[j], keys + i * keyType::numlanes); + valueVecs[i] = valueType::mask_loadu(valueType::zmm_max(), + ioMasks[j], + values + i * valueType::numlanes); + } + + // Sort each loaded vector + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + sort_vec_dispatch(keyVecs[i], valueVecs[i]); + } + + // Run the full merger + bitonic_fullmerge_n_vec(keyVecs, valueVecs); + + // Unmasked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + keyType::storeu(keys + i * keyType::numlanes, keyVecs[i]); + valueType::storeu(values + i * valueType::numlanes, valueVecs[i]); + } + // Masked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + keyType::mask_storeu( + keys + i * keyType::numlanes, ioMasks[j], keyVecs[i]); + valueType::mask_storeu( + values + i * valueType::numlanes, ioMasks[j], valueVecs[i]); + } } -template -X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(reg_t *key_zmm, - index_type *index_zmm) +template +X86_SIMD_SORT_INLINE void +argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - reg_t key_zmm4r = vtype1::permutexvar(rev_index1, key_zmm[4]); - reg_t key_zmm5r = vtype1::permutexvar(rev_index1, key_zmm[5]); - reg_t key_zmm6r = vtype1::permutexvar(rev_index1, key_zmm[6]); - reg_t key_zmm7r = vtype1::permutexvar(rev_index1, key_zmm[7]); - index_type index_zmm4r = vtype2::permutexvar(rev_index2, index_zmm[4]); - index_type index_zmm5r = vtype2::permutexvar(rev_index2, index_zmm[5]); - index_type index_zmm6r = vtype2::permutexvar(rev_index2, index_zmm[6]); - index_type index_zmm7r = vtype2::permutexvar(rev_index2, index_zmm[7]); - - reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm7r); - reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm6r); - reg_t key_reg_t3 = vtype1::min(key_zmm[2], key_zmm5r); - reg_t key_reg_t4 = vtype1::min(key_zmm[3], key_zmm4r); - - reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); - reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); - reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); - reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); - - typename vtype1::opmask_t movmask1 = vtype1::eq(key_reg_t1, key_zmm[0]); - typename vtype1::opmask_t movmask2 = vtype1::eq(key_reg_t2, key_zmm[1]); - typename vtype1::opmask_t movmask3 = vtype1::eq(key_reg_t3, key_zmm[2]); - typename vtype1::opmask_t movmask4 = vtype1::eq(key_reg_t4, key_zmm[3]); - - index_type index_reg_t1 - = vtype2::mask_mov(index_zmm7r, movmask1, index_zmm[0]); - index_type index_zmm_m1 - = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm7r); - index_type index_reg_t2 - 
= vtype2::mask_mov(index_zmm6r, movmask2, index_zmm[1]); - index_type index_zmm_m2 - = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm6r); - index_type index_reg_t3 - = vtype2::mask_mov(index_zmm5r, movmask3, index_zmm[2]); - index_type index_zmm_m3 - = vtype2::mask_mov(index_zmm[2], movmask3, index_zmm5r); - index_type index_reg_t4 - = vtype2::mask_mov(index_zmm4r, movmask4, index_zmm[3]); - index_type index_zmm_m4 - = vtype2::mask_mov(index_zmm[3], movmask4, index_zmm4r); - - reg_t key_reg_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4); - reg_t key_reg_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3); - reg_t key_reg_t7 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_reg_t8 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_reg_t5 = vtype2::permutexvar(rev_index2, index_zmm_m4); - index_type index_reg_t6 = vtype2::permutexvar(rev_index2, index_zmm_m3); - index_type index_reg_t7 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_reg_t8 = vtype2::permutexvar(rev_index2, index_zmm_m1); - - COEX(key_reg_t1, key_reg_t3, index_reg_t1, index_reg_t3); - COEX(key_reg_t2, key_reg_t4, index_reg_t2, index_reg_t4); - COEX(key_reg_t5, key_reg_t7, index_reg_t5, index_reg_t7); - COEX(key_reg_t6, key_reg_t8, index_reg_t6, index_reg_t8); - COEX(key_reg_t1, key_reg_t2, index_reg_t1, index_reg_t2); - COEX(key_reg_t3, key_reg_t4, index_reg_t3, index_reg_t4); - COEX(key_reg_t5, key_reg_t6, index_reg_t5, index_reg_t6); - COEX(key_reg_t7, key_reg_t8, index_reg_t7, index_reg_t8); - key_zmm[0] - = bitonic_merge_zmm_64bit(key_reg_t1, index_reg_t1); - key_zmm[1] - = bitonic_merge_zmm_64bit(key_reg_t2, index_reg_t2); - key_zmm[2] - = bitonic_merge_zmm_64bit(key_reg_t3, index_reg_t3); - key_zmm[3] - = bitonic_merge_zmm_64bit(key_reg_t4, index_reg_t4); - key_zmm[4] - = bitonic_merge_zmm_64bit(key_reg_t5, index_reg_t5); - key_zmm[5] - = bitonic_merge_zmm_64bit(key_reg_t6, index_reg_t6); - key_zmm[6] - = bitonic_merge_zmm_64bit(key_reg_t7, index_reg_t7); - key_zmm[7] - = bitonic_merge_zmm_64bit(key_reg_t8, index_reg_t8); - - index_zmm[0] = index_reg_t1; - index_zmm[1] = index_reg_t2; - index_zmm[2] = index_reg_t3; - index_zmm[3] = index_reg_t4; - index_zmm[4] = index_reg_t5; - index_zmm[5] = index_reg_t6; - index_zmm[6] = index_reg_t7; - index_zmm[7] = index_reg_t8; + static_assert(keyType::numlanes == indexType::numlanes, + "invalid pairing of value/index types"); + + constexpr int numVecs = maxN / keyType::numlanes; + constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); + constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); + static_assert(powerOfTwo == true && isMultiple == true, + "maxN must be keyType::numlanes times a power of 2"); + + argsort_n_vec(keys, indices, N); } -template -X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(reg_t *key_zmm, - index_type *index_zmm) +template +X86_SIMD_SORT_INLINE void kvsort_n(typename keyType::type_t *keys, + typename valueType::type_t *values, + int N) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - reg_t key_zmm8r = vtype1::permutexvar(rev_index1, key_zmm[8]); - reg_t key_zmm9r = vtype1::permutexvar(rev_index1, key_zmm[9]); - reg_t key_zmm10r = vtype1::permutexvar(rev_index1, key_zmm[10]); - reg_t key_zmm11r = vtype1::permutexvar(rev_index1, key_zmm[11]); - reg_t key_zmm12r = vtype1::permutexvar(rev_index1, key_zmm[12]); - reg_t key_zmm13r = vtype1::permutexvar(rev_index1, 
key_zmm[13]); - reg_t key_zmm14r = vtype1::permutexvar(rev_index1, key_zmm[14]); - reg_t key_zmm15r = vtype1::permutexvar(rev_index1, key_zmm[15]); - - index_type index_zmm8r = vtype2::permutexvar(rev_index2, index_zmm[8]); - index_type index_zmm9r = vtype2::permutexvar(rev_index2, index_zmm[9]); - index_type index_zmm10r = vtype2::permutexvar(rev_index2, index_zmm[10]); - index_type index_zmm11r = vtype2::permutexvar(rev_index2, index_zmm[11]); - index_type index_zmm12r = vtype2::permutexvar(rev_index2, index_zmm[12]); - index_type index_zmm13r = vtype2::permutexvar(rev_index2, index_zmm[13]); - index_type index_zmm14r = vtype2::permutexvar(rev_index2, index_zmm[14]); - index_type index_zmm15r = vtype2::permutexvar(rev_index2, index_zmm[15]); - - reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm15r); - reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm14r); - reg_t key_reg_t3 = vtype1::min(key_zmm[2], key_zmm13r); - reg_t key_reg_t4 = vtype1::min(key_zmm[3], key_zmm12r); - reg_t key_reg_t5 = vtype1::min(key_zmm[4], key_zmm11r); - reg_t key_reg_t6 = vtype1::min(key_zmm[5], key_zmm10r); - reg_t key_reg_t7 = vtype1::min(key_zmm[6], key_zmm9r); - reg_t key_reg_t8 = vtype1::min(key_zmm[7], key_zmm8r); - - reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm15r); - reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm14r); - reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm13r); - reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm12r); - reg_t key_zmm_m5 = vtype1::max(key_zmm[4], key_zmm11r); - reg_t key_zmm_m6 = vtype1::max(key_zmm[5], key_zmm10r); - reg_t key_zmm_m7 = vtype1::max(key_zmm[6], key_zmm9r); - reg_t key_zmm_m8 = vtype1::max(key_zmm[7], key_zmm8r); - - index_type index_reg_t1 = vtype2::mask_mov( - index_zmm15r, vtype1::eq(key_reg_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = vtype2::mask_mov( - index_zmm[0], vtype1::eq(key_reg_t1, key_zmm[0]), index_zmm15r); - index_type index_reg_t2 = vtype2::mask_mov( - index_zmm14r, vtype1::eq(key_reg_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = vtype2::mask_mov( - index_zmm[1], vtype1::eq(key_reg_t2, key_zmm[1]), index_zmm14r); - index_type index_reg_t3 = vtype2::mask_mov( - index_zmm13r, vtype1::eq(key_reg_t3, key_zmm[2]), index_zmm[2]); - index_type index_zmm_m3 = vtype2::mask_mov( - index_zmm[2], vtype1::eq(key_reg_t3, key_zmm[2]), index_zmm13r); - index_type index_reg_t4 = vtype2::mask_mov( - index_zmm12r, vtype1::eq(key_reg_t4, key_zmm[3]), index_zmm[3]); - index_type index_zmm_m4 = vtype2::mask_mov( - index_zmm[3], vtype1::eq(key_reg_t4, key_zmm[3]), index_zmm12r); - - index_type index_reg_t5 = vtype2::mask_mov( - index_zmm11r, vtype1::eq(key_reg_t5, key_zmm[4]), index_zmm[4]); - index_type index_zmm_m5 = vtype2::mask_mov( - index_zmm[4], vtype1::eq(key_reg_t5, key_zmm[4]), index_zmm11r); - index_type index_reg_t6 = vtype2::mask_mov( - index_zmm10r, vtype1::eq(key_reg_t6, key_zmm[5]), index_zmm[5]); - index_type index_zmm_m6 = vtype2::mask_mov( - index_zmm[5], vtype1::eq(key_reg_t6, key_zmm[5]), index_zmm10r); - index_type index_reg_t7 = vtype2::mask_mov( - index_zmm9r, vtype1::eq(key_reg_t7, key_zmm[6]), index_zmm[6]); - index_type index_zmm_m7 = vtype2::mask_mov( - index_zmm[6], vtype1::eq(key_reg_t7, key_zmm[6]), index_zmm9r); - index_type index_reg_t8 = vtype2::mask_mov( - index_zmm8r, vtype1::eq(key_reg_t8, key_zmm[7]), index_zmm[7]); - index_type index_zmm_m8 = vtype2::mask_mov( - index_zmm[7], vtype1::eq(key_reg_t8, key_zmm[7]), index_zmm8r); - - reg_t key_reg_t9 = vtype1::permutexvar(rev_index1, 
key_zmm_m8); - reg_t key_reg_t10 = vtype1::permutexvar(rev_index1, key_zmm_m7); - reg_t key_reg_t11 = vtype1::permutexvar(rev_index1, key_zmm_m6); - reg_t key_reg_t12 = vtype1::permutexvar(rev_index1, key_zmm_m5); - reg_t key_reg_t13 = vtype1::permutexvar(rev_index1, key_zmm_m4); - reg_t key_reg_t14 = vtype1::permutexvar(rev_index1, key_zmm_m3); - reg_t key_reg_t15 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_reg_t16 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_reg_t9 = vtype2::permutexvar(rev_index2, index_zmm_m8); - index_type index_reg_t10 = vtype2::permutexvar(rev_index2, index_zmm_m7); - index_type index_reg_t11 = vtype2::permutexvar(rev_index2, index_zmm_m6); - index_type index_reg_t12 = vtype2::permutexvar(rev_index2, index_zmm_m5); - index_type index_reg_t13 = vtype2::permutexvar(rev_index2, index_zmm_m4); - index_type index_reg_t14 = vtype2::permutexvar(rev_index2, index_zmm_m3); - index_type index_reg_t15 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_reg_t16 = vtype2::permutexvar(rev_index2, index_zmm_m1); - - COEX(key_reg_t1, key_reg_t5, index_reg_t1, index_reg_t5); - COEX(key_reg_t2, key_reg_t6, index_reg_t2, index_reg_t6); - COEX(key_reg_t3, key_reg_t7, index_reg_t3, index_reg_t7); - COEX(key_reg_t4, key_reg_t8, index_reg_t4, index_reg_t8); - COEX(key_reg_t9, key_reg_t13, index_reg_t9, index_reg_t13); - COEX( - key_reg_t10, key_reg_t14, index_reg_t10, index_reg_t14); - COEX( - key_reg_t11, key_reg_t15, index_reg_t11, index_reg_t15); - COEX( - key_reg_t12, key_reg_t16, index_reg_t12, index_reg_t16); - - COEX(key_reg_t1, key_reg_t3, index_reg_t1, index_reg_t3); - COEX(key_reg_t2, key_reg_t4, index_reg_t2, index_reg_t4); - COEX(key_reg_t5, key_reg_t7, index_reg_t5, index_reg_t7); - COEX(key_reg_t6, key_reg_t8, index_reg_t6, index_reg_t8); - COEX(key_reg_t9, key_reg_t11, index_reg_t9, index_reg_t11); - COEX( - key_reg_t10, key_reg_t12, index_reg_t10, index_reg_t12); - COEX( - key_reg_t13, key_reg_t15, index_reg_t13, index_reg_t15); - COEX( - key_reg_t14, key_reg_t16, index_reg_t14, index_reg_t16); - - COEX(key_reg_t1, key_reg_t2, index_reg_t1, index_reg_t2); - COEX(key_reg_t3, key_reg_t4, index_reg_t3, index_reg_t4); - COEX(key_reg_t5, key_reg_t6, index_reg_t5, index_reg_t6); - COEX(key_reg_t7, key_reg_t8, index_reg_t7, index_reg_t8); - COEX(key_reg_t9, key_reg_t10, index_reg_t9, index_reg_t10); - COEX( - key_reg_t11, key_reg_t12, index_reg_t11, index_reg_t12); - COEX( - key_reg_t13, key_reg_t14, index_reg_t13, index_reg_t14); - COEX( - key_reg_t15, key_reg_t16, index_reg_t15, index_reg_t16); - // - key_zmm[0] - = bitonic_merge_zmm_64bit(key_reg_t1, index_reg_t1); - key_zmm[1] - = bitonic_merge_zmm_64bit(key_reg_t2, index_reg_t2); - key_zmm[2] - = bitonic_merge_zmm_64bit(key_reg_t3, index_reg_t3); - key_zmm[3] - = bitonic_merge_zmm_64bit(key_reg_t4, index_reg_t4); - key_zmm[4] - = bitonic_merge_zmm_64bit(key_reg_t5, index_reg_t5); - key_zmm[5] - = bitonic_merge_zmm_64bit(key_reg_t6, index_reg_t6); - key_zmm[6] - = bitonic_merge_zmm_64bit(key_reg_t7, index_reg_t7); - key_zmm[7] - = bitonic_merge_zmm_64bit(key_reg_t8, index_reg_t8); - key_zmm[8] - = bitonic_merge_zmm_64bit(key_reg_t9, index_reg_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_reg_t10, - index_reg_t10); - key_zmm[10] = bitonic_merge_zmm_64bit(key_reg_t11, - index_reg_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_reg_t12, - index_reg_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_reg_t13, - index_reg_t13); - key_zmm[13] = 
bitonic_merge_zmm_64bit(key_reg_t14, - index_reg_t14); - key_zmm[14] = bitonic_merge_zmm_64bit(key_reg_t15, - index_reg_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_reg_t16, - index_reg_t16); - - index_zmm[0] = index_reg_t1; - index_zmm[1] = index_reg_t2; - index_zmm[2] = index_reg_t3; - index_zmm[3] = index_reg_t4; - index_zmm[4] = index_reg_t5; - index_zmm[5] = index_reg_t6; - index_zmm[6] = index_reg_t7; - index_zmm[7] = index_reg_t8; - index_zmm[8] = index_reg_t9; - index_zmm[9] = index_reg_t10; - index_zmm[10] = index_reg_t11; - index_zmm[11] = index_reg_t12; - index_zmm[12] = index_reg_t13; - index_zmm[13] = index_reg_t14; - index_zmm[14] = index_reg_t15; - index_zmm[15] = index_reg_t16; + static_assert(keyType::numlanes == valueType::numlanes, + "invalid pairing of key/value types"); + + constexpr int numVecs = maxN / keyType::numlanes; + constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); + constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); + static_assert(powerOfTwo == true && isMultiple == true, + "maxN must be keyType::numlanes times a power of 2"); + + kvsort_n_vec(keys, values, N); } -#endif // AVX512_KEYVALUE_NETWORKS + +#endif
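
The core of this patch replaces the hand-written bitonic_merge_two/four/eight/sixteen_zmm_64bit routines with the generic bitonic_merge_n_vec / bitonic_clean_n_vec / bitonic_fullmerge_n_vec templates. The following is a minimal scalar sketch of that merge structure, not the library's code: each "vector" holds a single lane, so keyType::reverse() degenerates to the identity and pairing element i with numVecs - i - 1 plays the role of reversing the upper vectors before COEX. The names coex, bitonic_clean, bitonic_merge and bitonic_fullmerge are illustrative only.

```cpp
#include <cstdio>
#include <utility>

// Compare-exchange a key pair and move the paired values with it: the scalar
// analogue of COEX(key1, key2, value1, value2) in the patch.
template <typename K, typename V>
void coex(K &k1, K &k2, V &v1, V &v2)
{
    if (k2 < k1) {
        std::swap(k1, k2);
        std::swap(v1, v2);
    }
}

// Half-cleaner: same loop structure as bitonic_clean_n_vec, compare-exchanging
// elements num/2 apart while num halves down to 2.
template <int numVecs, typename K, typename V>
void bitonic_clean(K *keys, V *values)
{
    for (int num = numVecs / 2; num >= 2; num /= 2) {
        for (int j = 0; j < numVecs; j += num) {
            for (int i = 0; i < num / 2; i++) {
                coex(keys[i + j], keys[i + j + num / 2],
                     values[i + j], values[i + j + num / 2]);
            }
        }
    }
}

// Merge two sorted halves of length numVecs/2: pair element i with element
// numVecs - i - 1 (the scalar stand-in for reversing the upper vectors before
// COEX in bitonic_merge_n_vec), then run the cleaner.
template <int numVecs, typename K, typename V>
void bitonic_merge(K *keys, V *values)
{
    for (int i = 0; i < numVecs / 2; i++) {
        coex(keys[i], keys[numVecs - i - 1],
             values[i], values[numVecs - i - 1]);
    }
    bitonic_clean<numVecs>(keys, values);
}

// Full merger: merge sorted runs of 2, then 4, ... up to numVecs, mirroring
// how bitonic_fullmerge_n_vec doubles numPer each round.
template <int numVecs, int numPer, typename K, typename V>
void bitonic_fullmerge(K *keys, V *values)
{
    if constexpr (numPer <= numVecs) {
        for (int i = 0; i < numVecs / numPer; i++) {
            bitonic_merge<numPer>(keys + i * numPer, values + i * numPer);
        }
        bitonic_fullmerge<numVecs, numPer * 2>(keys, values);
    }
}

int main()
{
    int keys[8] = {5, 1, 7, 3, 8, 2, 6, 4};
    int vals[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    // Each single element is a trivially sorted "vector", so the full merger
    // alone sorts the array; the real code first sorts every SIMD register
    // with sort_vec_dispatch before merging.
    bitonic_fullmerge<8, 2>(keys, vals);
    for (int i = 0; i < 8; i++) {
        std::printf("%d:%d ", keys[i], vals[i]);
    }
    std::printf("\n");
    return 0;
}
```

Because the per-register merge is reached only through bitonic_merge_dispatch / sort_vec_dispatch, adding support for another register width (e.g. 4-lane AVX2 vectors) only requires a new `if constexpr (numlanes == ...)` branch rather than a new family of fixed-size sort functions.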
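
The generic argsort_n_vec / kvsort_n_vec paths also replace the hand-built partial-load masks of the removed fixed-size routines (e.g. `(0x01 << N) - 0x01`) with a per-vector computation: only the upper half of the vectors can be partially filled, because the function recurses to half numVecs whenever `N * 2 <= numVecs * numlanes`. A small sketch of that mask computation, assuming 8 lanes per register and a bitmask-style `(1 << n) - 1` opmask; the values here are illustrative, not taken from the patch:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int numlanes = 8; // one AVX-512 register holds eight 64-bit lanes
    constexpr int numVecs = 4;  // covers up to 32 elements
    int N = 27;                 // actual number of elements to sort

    // Vectors 0 .. numVecs/2 - 1 are always full and loaded unmasked; the
    // upper half gets a mask covering only the elements that exist.
    for (int i = numVecs / 2; i < numVecs; i++) {
        int num_to_read = std::min(std::max(0, N - i * numlanes), numlanes);
        uint8_t mask = (uint8_t)((1u << num_to_read) - 1u);
        std::printf("vector %d: load %d lanes, mask 0x%02X\n",
                    i, num_to_read, mask);
    }
    return 0;
}
```

For N = 27 this yields a full 0xFF mask for vector 2 and 0x07 for vector 3, which is the same coverage the removed argsort_32_64bit achieved with its hand-written `combined_mask` shifts.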