diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp
index 2afc4d1d..4c1123e4 100644
--- a/lib/x86simdsort-avx2.cpp
+++ b/lib/x86simdsort-avx2.cpp
@@ -1,4 +1,5 @@
 // AVX2 specific routines:
+
 #include "x86simdsort-static-incl.h"
 #include "x86simdsort-internal.h"
@@ -33,6 +34,38 @@
         return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
     }
 
+#define DEFINE_KEYVALUE_METHODS(type) \
+    template <> \
+    void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \
+    { \
+        x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
+    } \
+    template <> \
+    void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \
+    { \
+        x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
+    } \
+    template <> \
+    void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \
+    { \
+        x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
+    } \
+    template <> \
+    void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \
+    { \
+        x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
+    } \
+    template <> \
+    void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \
+    { \
+        x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
+    } \
+    template <> \
+    void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \
+    { \
+        x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
+    }
+
 namespace xss {
 namespace avx2 {
     DEFINE_ALL_METHODS(uint32_t)
@@ -41,5 +74,11 @@ namespace avx2 {
     DEFINE_ALL_METHODS(uint64_t)
     DEFINE_ALL_METHODS(int64_t)
     DEFINE_ALL_METHODS(double)
+    DEFINE_KEYVALUE_METHODS(uint64_t)
+    DEFINE_KEYVALUE_METHODS(int64_t)
+    DEFINE_KEYVALUE_METHODS(double)
+    DEFINE_KEYVALUE_METHODS(uint32_t)
+    DEFINE_KEYVALUE_METHODS(int32_t)
+    DEFINE_KEYVALUE_METHODS(float)
 } // namespace avx2
-} // namespace xss
+} // namespace xss
\ No newline at end of file
diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp
index 829dd7b8..e51c51ed 100644
--- a/lib/x86simdsort-skx.cpp
+++ b/lib/x86simdsort-skx.cpp
@@ -1,4 +1,5 @@
 // SKX specific routines:
+
 #include "x86simdsort-static-incl.h"
 #include "x86simdsort-internal.h"
diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp
index 0c16f148..2f268abc 100644
--- a/lib/x86simdsort.cpp
+++ b/lib/x86simdsort.cpp
@@ -208,12 +208,12 @@ DISPATCH_ALL(argselect, (ISA_LIST("avx512_skx", "avx2")))
 
 #define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \
-    DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx"))) \
-    DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx"))) \
-    DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx"))) \
-    DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx"))) \
-    DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx"))) \
-    DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx")))
+    DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx", "avx2"))) \
+    DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx", "avx2"))) \
+    DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx", "avx2"))) \
+    DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx", "avx2"))) \
+    DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx", "avx2"))) \
+    DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx", "avx2")))
 
 DISPATCH_KEYVALUE_SORT_FORTYPE(uint64_t)
 DISPATCH_KEYVALUE_SORT_FORTYPE(int64_t)
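For context, a minimal usage sketch of the path these dispatch changes enable — calling the runtime-dispatched key-value sort through the public entry point (assuming `x86simdsort::keyvalue_qsort` as declared in `lib/x86simdsort.h`; this snippet is illustrative and not part of the patch):

```cpp
// Sketch only: sorts keys and applies the same permutation to vals.
// With this patch, an AVX2-only machine dispatches to xss::avx2::keyvalue_qsort
// instead of requiring AVX-512.
#include "x86simdsort.h"
#include <cstdint>
#include <vector>

int main()
{
    std::vector<uint64_t> keys = {42, 7, 19, 3};
    std::vector<double> vals = {4.2, 0.7, 1.9, 0.3};
    x86simdsort::keyvalue_qsort(keys.data(), vals.data(), keys.size());
    // keys is now {3, 7, 19, 42}; vals followed the keys: {0.3, 0.7, 1.9, 4.2}
}
```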
diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp
index b8cca7a4..56ffc232 100644
--- a/src/avx2-32bit-qsort.hpp
+++ b/src/avx2-32bit-qsort.hpp
@@ -99,6 +99,10 @@ struct avx2_vector<int32_t> {
         auto mask = ((0x1ull << num_to_read) - 0x1ull);
         return convert_int_to_avx2_mask(mask);
     }
+    static opmask_t convert_int_to_mask(uint64_t intMask)
+    {
+        return convert_int_to_avx2_mask(intMask);
+    }
     static ymmi_t
     seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
@@ -268,6 +272,10 @@ struct avx2_vector<uint32_t> {
         auto mask = ((0x1ull << num_to_read) - 0x1ull);
         return convert_int_to_avx2_mask(mask);
     }
+    static opmask_t convert_int_to_mask(uint64_t intMask)
+    {
+        return convert_int_to_avx2_mask(intMask);
+    }
     static ymmi_t
     seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
@@ -444,6 +452,10 @@ struct avx2_vector<float> {
         auto mask = ((0x1ull << num_to_read) - 0x1ull);
         return convert_int_to_avx2_mask(mask);
     }
+    static opmask_t convert_int_to_mask(uint64_t intMask)
+    {
+        return convert_int_to_avx2_mask(intMask);
+    }
     static int32_t convert_mask_to_int(opmask_t mask)
    {
         return convert_avx2_mask_to_int(mask);
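Since AVX2 has no opmask registers, `convert_int_to_mask` lets the shared networks keep using the same integer mask constants (`0xAA`, `0xCCCC`, ...) on both ISAs. A standalone sketch of the bit-to-lane expansion this relies on (not the library's `convert_int_to_avx2_mask`, just an illustration of the idea):

```cpp
#include <immintrin.h>
#include <cstdint>

// Expand bit i of an 8-bit integer mask into all 32 bits of lane i, which is
// how a plain vector register can stand in for an AVX-512 opmask on AVX2.
static __m256i int_to_lane_mask(uint8_t m)
{
    const __m256i bit = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
    const __m256i v = _mm256_set1_epi32(m);
    // Lane i is all-ones iff bit i of m is set, else all-zeros
    return _mm256_cmpeq_epi32(_mm256_and_si256(v, bit), bit);
}
```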
diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h
index 8f5fdce9..6e2b6eff 100644
--- a/src/avx512-64bit-common.h
+++ b/src/avx512-64bit-common.h
@@ -24,6 +24,7 @@
 template <typename vtype, typename reg_t>
 X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm);
 
 struct avx512_64bit_swizzle_ops;
+struct avx512_ymm_64bit_swizzle_ops;
 
 template <>
 struct ymm_vector<float> {
@@ -34,6 +35,8 @@ struct ymm_vector<float> {
     static const uint8_t numlanes = 8;
     static constexpr simd_type vec_type = simd_type::AVX512;
 
+    using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
+
     static type_t type_max()
     {
         return X86_SIMD_SORT_INFINITYF;
@@ -212,6 +215,14 @@ struct ymm_vector<float> {
         const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
         return permutexvar(rev_index, ymm);
     }
+    static int double_compressstore(type_t *left_addr,
+                                    type_t *right_addr,
+                                    opmask_t k,
+                                    reg_t reg)
+    {
+        return avx512_double_compressstore<ymm_vector<float>>(
+                left_addr, right_addr, k, reg);
+    }
 };
 template <>
 struct ymm_vector<uint32_t> {
@@ -222,6 +233,8 @@ struct ymm_vector<uint32_t> {
     static const uint8_t numlanes = 8;
     static constexpr simd_type vec_type = simd_type::AVX512;
 
+    using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
+
     static type_t type_max()
     {
         return X86_SIMD_SORT_MAX_UINT32;
@@ -386,6 +399,14 @@ struct ymm_vector<uint32_t> {
         const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
         return permutexvar(rev_index, ymm);
     }
+    static int double_compressstore(type_t *left_addr,
+                                    type_t *right_addr,
+                                    opmask_t k,
+                                    reg_t reg)
+    {
+        return avx512_double_compressstore<ymm_vector<uint32_t>>(
+                left_addr, right_addr, k, reg);
+    }
 };
 template <>
 struct ymm_vector<int32_t> {
@@ -396,6 +417,8 @@ struct ymm_vector<int32_t> {
     static const uint8_t numlanes = 8;
     static constexpr simd_type vec_type = simd_type::AVX512;
 
+    using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
+
     static type_t type_max()
     {
         return X86_SIMD_SORT_MAX_INT32;
@@ -560,6 +583,14 @@ struct ymm_vector<int32_t> {
         const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
         return permutexvar(rev_index, ymm);
     }
+    static int double_compressstore(type_t *left_addr,
+                                    type_t *right_addr,
+                                    opmask_t k,
+                                    reg_t reg)
+    {
+        return avx512_double_compressstore<ymm_vector<int32_t>>(
+                left_addr, right_addr, k, reg);
+    }
 };
 template <>
 struct zmm_vector<int64_t> {
@@ -1215,4 +1246,77 @@ struct avx512_64bit_swizzle_ops {
     }
 };
 
+struct avx512_ymm_64bit_swizzle_ops {
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg)
+    {
+        __m256i v = vtype::cast_to(reg);
+
+        if constexpr (scale == 2) {
+            __m256 vf = _mm256_castsi256_ps(v);
+            vf = _mm256_permute_ps(vf, 0b10110001);
+            v = _mm256_castps_si256(vf);
+        }
+        else if constexpr (scale == 4) {
+            __m256 vf = _mm256_castsi256_ps(v);
+            vf = _mm256_permute_ps(vf, 0b01001110);
+            v = _mm256_castps_si256(vf);
+        }
+        else if constexpr (scale == 8) {
+            v = _mm256_permute2x128_si256(v, v, 0b00000001);
+        }
+        else {
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v);
+    }
+
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t
+    reverse_n(typename vtype::reg_t reg)
+    {
+        __m256i v = vtype::cast_to(reg);
+
+        if constexpr (scale == 2) { return swap_n<vtype, 2>(reg); }
+        else if constexpr (scale == 4) {
+            constexpr uint64_t mask = 0b00011011;
+            __m256 vf = _mm256_castsi256_ps(v);
+            vf = _mm256_permute_ps(vf, mask);
+            v = _mm256_castps_si256(vf);
+        }
+        else if constexpr (scale == 8) {
+            return vtype::reverse(reg);
+        }
+        else {
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v);
+    }
+
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t
+    merge_n(typename vtype::reg_t reg, typename vtype::reg_t other)
+    {
+        __m256i v1 = vtype::cast_to(reg);
+        __m256i v2 = vtype::cast_to(other);
+
+        if constexpr (scale == 2) {
+            v1 = _mm256_blend_epi32(v1, v2, 0b01010101);
+        }
+        else if constexpr (scale == 4) {
+            v1 = _mm256_blend_epi32(v1, v2, 0b00110011);
+        }
+        else if constexpr (scale == 8) {
+            v1 = _mm256_blend_epi32(v1, v2, 0b00001111);
+        }
+        else {
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v1);
+    }
+};
+
 #endif
diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h
index 0d0e4400..1f849004 100644
--- a/src/x86simdsort-static-incl.h
+++ b/src/x86simdsort-static-incl.h
@@ -100,6 +100,12 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false);
         std::iota(indices.begin(), indices.end(), 0); \
         x86simdsortStatic::argselect(arr, indices.data(), k, size, hasnan); \
         return indices; \
+    } \
+    template <typename T1, typename T2> \
+    X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_qsort( \
+            T1 *key, T2 *val, size_t size, bool hasnan) \
+    { \
+        ISA##_qsort_kv(key, val, size, hasnan); \
     }
 
 /*
@@ -107,13 +113,13 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false);
 */
 #include "xss-common-qsort.h"
 #include "xss-common-argsort.h"
+#include "xss-common-keyvaluesort.hpp"
 
 #if defined(__AVX512DQ__) && defined(__AVX512VL__)
 /* 32-bit and 64-bit dtypes vector definitions on SKX */
 #include "avx512-32bit-qsort.hpp"
 #include "avx512-64bit-qsort.hpp"
 #include "avx512-64bit-argsort.hpp"
-#include "avx512-64bit-keyvaluesort.hpp"
 
 /* 16-bit dtypes vector definitions on ICL */
 #if defined(__AVX512BW__) && defined(__AVX512VBMI2__)
@@ -126,14 +132,6 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false);
 
 XSS_METHODS(avx512)
 
-// key-value currently only on avx512
-template <typename T1, typename T2>
-X86_SIMD_SORT_FINLINE void
-x86simdsortStatic::keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan)
-{
-    avx512_qsort_kv(key, val, size, hasnan);
-}
-
 #elif defined(__AVX512F__)
 
 #error "x86simdsort requires AVX512DQ and AVX512VL to be enabled in addition to AVX512F to use AVX512"
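The `scale` parameter follows the convention of the existing swizzle structs: `swap_n<scale>` exchanges the two halves of every group of `scale` lanes, and `reverse_n<scale>` reverses each group. A tiny self-contained demo of the two `_mm256_permute_ps` constants used above (illustration only, independent of the library):

```cpp
#include <immintrin.h>
#include <cstdio>

int main()
{
    __m256 v = _mm256_setr_ps(0, 1, 2, 3, 4, 5, 6, 7);
    __m256 swap2 = _mm256_permute_ps(v, 0b10110001); // swap_n<2>
    __m256 rev4 = _mm256_permute_ps(v, 0b00011011); // reverse_n<4>
    float a[8], b[8];
    _mm256_storeu_ps(a, swap2);
    _mm256_storeu_ps(b, rev4);
    for (int i = 0; i < 8; i++) printf("%.0f ", a[i]); // 1 0 3 2 5 4 7 6
    printf("\n");
    for (int i = 0; i < 8; i++) printf("%.0f ", b[i]); // 3 2 1 0 7 6 5 4
    printf("\n");
}
```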
diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp
similarity index 90%
rename from src/avx512-64bit-keyvaluesort.hpp
rename to src/xss-common-keyvaluesort.hpp
index 9c4f1459..4699b8a1 100644
--- a/src/avx512-64bit-keyvaluesort.hpp
+++ b/src/xss-common-keyvaluesort.hpp
@@ -8,7 +8,7 @@
 #ifndef AVX512_QSORT_64BIT_KV
 #define AVX512_QSORT_64BIT_KV
 
-#include "avx512-64bit-common.h"
+#include "xss-common-qsort.h"
 #include "xss-network-keyvaluesort.hpp"
 
 /*
@@ -33,15 +33,14 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t1 *keys,
 {
     /* which elements are larger than the pivot */
     typename vtype1::opmask_t gt_mask = vtype1::ge(keys_vec, pivot_vec);
-    int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask);
-    vtype1::mask_compressstoreu(
-            keys + left, vtype1::knot_opmask(gt_mask), keys_vec);
-    vtype1::mask_compressstoreu(
-            keys + right - amount_gt_pivot, gt_mask, keys_vec);
-    vtype2::mask_compressstoreu(
-            indexes + left, vtype2::knot_opmask(gt_mask), indexes_vec);
-    vtype2::mask_compressstoreu(
-            indexes + right - amount_gt_pivot, gt_mask, indexes_vec);
+
+    int32_t amount_gt_pivot = vtype1::double_compressstore(
+            keys + left, keys + right - vtype1::numlanes, gt_mask, keys_vec);
+    vtype2::double_compressstore(indexes + left,
+                                 indexes + right - vtype2::numlanes,
+                                 resize_mask<vtype1, vtype2>(gt_mask),
+                                 indexes_vec);
+
     *smallest_vec = vtype1::min(keys_vec, *smallest_vec);
     *biggest_vec = vtype1::max(keys_vec, *biggest_vec);
     return amount_gt_pivot;
@@ -363,11 +362,11 @@ template <typename vtype1,
           typename vtype2,
           typename type1_t,
           typename type2_t>
-X86_SIMD_SORT_INLINE void qsort_64bit_(type1_t *keys,
-                                       type2_t *indexes,
-                                       arrsize_t left,
-                                       arrsize_t right,
-                                       int max_iters)
+X86_SIMD_SORT_INLINE void kvsort_(type1_t *keys,
+                                  type2_t *indexes,
+                                  arrsize_t left,
+                                  arrsize_t right,
+                                  int max_iters)
 {
     /*
      * Resort to std::sort if quicksort isnt making any progress
@@ -393,32 +392,35 @@ X86_SIMD_SORT_INLINE void qsort_64bit_(type1_t *keys,
     arrsize_t pivot_index = kvpartition_unrolled<vtype1, vtype2, 4>(
             keys, indexes, left, right + 1, pivot, &smallest, &biggest);
     if (pivot != smallest) {
-        qsort_64bit_<vtype1, vtype2>(
+        kvsort_<vtype1, vtype2>(
                 keys, indexes, left, pivot_index - 1, max_iters - 1);
     }
     if (pivot != biggest) {
-        qsort_64bit_<vtype1, vtype2>(
+        kvsort_<vtype1, vtype2>(
                 keys, indexes, pivot_index, right, max_iters - 1);
     }
 }
 
-template <typename T1, typename T2>
+template <typename T1,
+          typename T2,
+          template <typename...>
+          typename full_vector,
+          template <typename...>
+          typename half_vector>
 X86_SIMD_SORT_INLINE void
-avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
+xss_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan)
 {
     using keytype = typename std::conditional<sizeof(T1) == sizeof(int32_t),
-                                              ymm_vector<T1>,
-                                              zmm_vector<T1>>::type;
+                                              half_vector<T1>,
+                                              full_vector<T1>>::type;
     using valtype = typename std::conditional<sizeof(T2) == sizeof(int32_t),
-                                              ymm_vector<T2>,
-                                              zmm_vector<T2>>::type;
-/*
- * Enable testing the heapsort key-value sort in the CI:
- */
+                                              half_vector<T2>,
+                                              full_vector<T2>>::type;
+
 #ifdef XSS_TEST_KEYVALUE_BASE_CASE
     int maxiters = -1;
     bool minarrsize = true;
@@ -431,14 +433,31 @@ avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
         arrsize_t nan_count = 0;
         if constexpr (xss::fp::is_floating_point_v<T1>) {
             if (UNLIKELY(hasnan)) {
-                nan_count = replace_nan_with_inf<zmm_vector<T1>>(keys, arrsize);
+                nan_count
+                        = replace_nan_with_inf<full_vector<T1>>(keys, arrsize);
             }
         }
         else {
             UNUSED(hasnan);
         }
 
-        qsort_64bit_<keytype, valtype>(keys, indexes, 0, arrsize - 1, maxiters);
+        kvsort_<keytype, valtype>(keys, indexes, 0, arrsize - 1, maxiters);
         replace_inf_with_nan(keys, arrsize, nan_count);
     }
 }
 
+template <typename T1, typename T2>
+X86_SIMD_SORT_INLINE void
+avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
+{
+    xss_qsort_kv<T1, T2, zmm_vector, ymm_vector>(
+            keys, indexes, arrsize, hasnan);
+}
+
+template <typename T1, typename T2>
+X86_SIMD_SORT_INLINE void
+avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
+{
+    xss_qsort_kv<T1, T2, avx2_vector, avx2_half_vector>(
+            keys, indexes, arrsize, hasnan);
+}
 #endif // AVX512_QSORT_64BIT_KV
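The partition rewrite funnels both ISAs through a single `double_compressstore` primitive instead of four explicit `mask_compressstoreu` calls. A scalar model of the contract `partition_vec` now relies on (hypothetical helper for illustration; the real AVX-512 version is the `avx512_double_compressstore` routine the ymm/zmm vector structs delegate to):

```cpp
#include <cstdint>

// Elements with a 0 mask bit are packed ascending from left_addr; elements
// with a 1 bit (flagged by the pivot comparison) are packed so that they end
// exactly at right_addr + N. Returns the number of 1 bits, which is what
// partition_vec uses to advance its cursors.
template <typename T, int N>
int double_compressstore_model(T *left_addr, T *right_addr, uint32_t mask, const T (&v)[N])
{
    int cnt = 0;
    for (int i = 0; i < N; i++)
        cnt += (mask >> i) & 1;
    for (int i = 0, lo = 0, hi = N - cnt; i < N; i++) {
        if ((mask >> i) & 1)
            right_addr[hi++] = v[i];
        else
            left_addr[lo++] = v[i];
    }
    return cnt;
}
```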
diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp
index bb9b9bcd..3cb30378 100644
--- a/src/xss-network-keyvaluesort.hpp
+++ b/src/xss-network-keyvaluesort.hpp
@@ -1,6 +1,21 @@
 #ifndef XSS_KEYVALUE_NETWORKS
 #define XSS_KEYVALUE_NETWORKS
 
+#include "xss-common-includes.h"
+
+template <typename vtype, typename maskType>
+typename vtype::opmask_t convert_int_to_mask(maskType mask)
+{
+    if constexpr (vtype::vec_type == simd_type::AVX512) { return mask; }
+    else if constexpr (vtype::vec_type == simd_type::AVX2) {
+        return vtype::convert_int_to_mask(mask);
+    }
+    else {
+        static_assert(always_false<vtype>,
+                      "Error in func convert_int_to_mask");
+    }
+}
+
 template <typename keyType, typename valueType>
 typename valueType::opmask_t resize_mask(typename keyType::opmask_t mask)
 {
@@ -70,66 +85,74 @@ template <typename vtype1,
 X86_SIMD_SORT_INLINE reg_t sort_reg_16lanes(reg_t key_zmm, index_type &index_zmm)
 {
+    using key_swizzle = typename vtype1::swizzle_ops;
+    using index_swizzle = typename vtype2::swizzle_ops;
+
+    const auto oxAAAA = convert_int_to_mask<vtype1>(0xAAAA);
+    const auto oxCCCC = convert_int_to_mask<vtype1>(0xCCCC);
+    const auto oxFF00 = convert_int_to_mask<vtype1>(0xFF00);
+    const auto oxF0F0 = convert_int_to_mask<vtype1>(0xF0F0);
+
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(key_zmm),
+            key_swizzle::template reverse_n<vtype1, 2>(key_zmm),
             index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(index_zmm),
-            0xAAAA);
+            index_swizzle::template reverse_n<vtype2, 2>(index_zmm),
+            oxAAAA);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(0, 1, 2, 3)>(key_zmm),
+            key_swizzle::template reverse_n<vtype1, 4>(key_zmm),
             index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(0, 1, 2, 3)>(index_zmm),
-            0xCCCC);
+            index_swizzle::template reverse_n<vtype2, 4>(index_zmm),
+            oxCCCC);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(key_zmm),
+            key_swizzle::template reverse_n<vtype1, 2>(key_zmm),
             index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(index_zmm),
-            0xAAAA);
+            index_swizzle::template reverse_n<vtype2, 2>(index_zmm),
+            oxAAAA);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
             vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_3), key_zmm),
             index_zmm,
             vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_3), index_zmm),
-            0xF0F0);
+            oxF0F0);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(key_zmm),
+            key_swizzle::template swap_n<vtype1, 4>(key_zmm),
            index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(index_zmm),
-            0xCCCC);
+            index_swizzle::template swap_n<vtype2, 4>(index_zmm),
+            oxCCCC);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(key_zmm),
+            key_swizzle::template reverse_n<vtype1, 2>(key_zmm),
            index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(index_zmm),
-            0xAAAA);
+            index_swizzle::template reverse_n<vtype2, 2>(index_zmm),
+            oxAAAA);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
             vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_5), key_zmm),
             index_zmm,
             vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_5), index_zmm),
-            0xFF00);
+            oxFF00);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
             vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_6), key_zmm),
             index_zmm,
             vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_6), index_zmm),
-            0xF0F0);
+            oxF0F0);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(key_zmm),
+            key_swizzle::template swap_n<vtype1, 4>(key_zmm),
            index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(index_zmm),
-            0xCCCC);
+            index_swizzle::template swap_n<vtype2, 4>(index_zmm),
+            oxCCCC);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(key_zmm),
+            key_swizzle::template reverse_n<vtype1, 2>(key_zmm),
            index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(index_zmm),
-            0xAAAA);
+            index_swizzle::template reverse_n<vtype2, 2>(index_zmm),
+            oxAAAA);
     return key_zmm;
 }
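Every step above is an instance of the library's `cmp_merge`: compare a register with a permuted copy and keep, per lane, the min or the max as selected by the mask, with each index lane following its key. A scalar model of one key-value step (illustrative only; tie handling in the real code goes through an equality mask):

```cpp
#include <algorithm>
#include <cstdint>

// One compare-exchange step on N lanes: perm encodes the shuffle (e.g. the
// reverse_n / swap_n patterns) and bit i of mask says lane i keeps the max.
template <int N>
void cmp_merge_kv_model(float (&key)[N], uint32_t (&idx)[N],
                        const int (&perm)[N], uint32_t mask)
{
    float k2[N];
    uint32_t i2[N];
    for (int i = 0; i < N; i++) {
        float a = key[i], b = key[perm[i]];
        float sel = ((mask >> i) & 1) ? std::max(a, b) : std::min(a, b);
        // the payload index travels with whichever key was kept
        i2[i] = (sel == a) ? idx[i] : idx[perm[i]];
        k2[i] = sel;
    }
    std::copy(k2, k2 + N, key);
    std::copy(i2, i2 + N, idx);
}
```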
@@ -141,30 +164,38 @@ template <typename vtype1,
 X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_16lanes(reg_t key_zmm,
                                                      index_type &index_zmm)
 {
+    using key_swizzle = typename vtype1::swizzle_ops;
+    using index_swizzle = typename vtype2::swizzle_ops;
+
+    const auto oxAAAA = convert_int_to_mask<vtype1>(0xAAAA);
+    const auto oxCCCC = convert_int_to_mask<vtype1>(0xCCCC);
+    const auto oxFF00 = convert_int_to_mask<vtype1>(0xFF00);
+    const auto oxF0F0 = convert_int_to_mask<vtype1>(0xF0F0);
+
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
             vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_7), key_zmm),
             index_zmm,
             vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_7), index_zmm),
-            0xFF00);
+            oxFF00);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
             vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_6), key_zmm),
             index_zmm,
             vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_6), index_zmm),
-            0xF0F0);
+            oxF0F0);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(key_zmm),
+            key_swizzle::template swap_n<vtype1, 4>(key_zmm),
             index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(index_zmm),
-            0xCCCC);
+            index_swizzle::template swap_n<vtype2, 4>(index_zmm),
+            oxCCCC);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(key_zmm),
+            key_swizzle::template reverse_n<vtype1, 2>(key_zmm),
             index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(index_zmm),
-            0xAAAA);
+            index_swizzle::template reverse_n<vtype2, 2>(index_zmm),
+            oxAAAA);
     return key_zmm;
 }
@@ -174,44 +205,48 @@ template <typename vtype1,
 X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes(reg_t key_zmm, index_type &index_zmm)
 {
-    const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2);
-    const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2);
+    using key_swizzle = typename vtype1::swizzle_ops;
+    using index_swizzle = typename vtype2::swizzle_ops;
+
+    const auto oxAA = convert_int_to_mask<vtype1>(0xAA);
+    const auto oxCC = convert_int_to_mask<vtype1>(0xCC);
+    const auto oxF0 = convert_int_to_mask<vtype1>(0xF0);
+
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(key_zmm),
+            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
             index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(index_zmm),
-            0xAA);
+            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
+            oxAA);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
             vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_1), key_zmm),
             index_zmm,
             vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_1), index_zmm),
-            0xCC);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(key_zmm),
-            index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(index_zmm),
-            0xAA);
+            oxCC);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::permutexvar(rev_index1, key_zmm),
+            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
             index_zmm,
-            vtype2::permutexvar(rev_index2, index_zmm),
-            0xF0);
+            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
+            oxAA);
+    key_zmm = cmp_merge<vtype1, vtype2>(key_zmm,
+                                        vtype1::reverse(key_zmm),
+                                        index_zmm,
+                                        vtype2::reverse(index_zmm),
+                                        oxF0);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
             vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_3), key_zmm),
             index_zmm,
             vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_3), index_zmm),
-            0xCC);
+            oxCC);
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(key_zmm),
+            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
             index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(index_zmm),
-            0xAA);
+            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
+            oxAA);
     return key_zmm;
 }
@@ -255,6 +290,12 @@ template <typename vtype1,
 X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_8lanes(reg_t key_zmm,
                                                     index_type &index_zmm)
 {
+    using key_swizzle = typename vtype1::swizzle_ops;
+    using index_swizzle = typename vtype2::swizzle_ops;
+
+    const auto oxAA = convert_int_to_mask<vtype1>(0xAA);
+    const auto oxCC = convert_int_to_mask<vtype1>(0xCC);
+    const auto oxF0 = convert_int_to_mask<vtype1>(0xF0);
 
     // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
     key_zmm = cmp_merge<vtype1, vtype2>(
@@ -262,21 +303,21 @@ X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_8lanes(reg_t key_zmm,
             vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_4), key_zmm),
             index_zmm,
             vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_4), index_zmm),
-            0xF0);
+            oxF0);
     // 2) half_cleaner[4]
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
             vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_3), key_zmm),
             index_zmm,
             vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_3), index_zmm),
-            0xCC);
+            oxCC);
     // 3) half_cleaner[1]
     key_zmm = cmp_merge<vtype1, vtype2>(
             key_zmm,
-            vtype1::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(key_zmm),
+            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
             index_zmm,
-            vtype2::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(index_zmm),
-            0xAA);
+            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
+            oxAA);
     return key_zmm;
 }
@@ -552,9 +593,10 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys,
     for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) {
         keyVecs[i] = keyType::mask_loadu(
                 keyType::zmm_max(), ioMasks[j], keys + i * keyType::numlanes);
-        valueVecs[i] = valueType::mask_loadu(valueType::zmm_max(),
-                                             ioMasks[j],
-                                             values + i * valueType::numlanes);
+        valueVecs[i] = valueType::mask_loadu(
+                valueType::zmm_max(),
+                resize_mask<keyType, valueType>(ioMasks[j]),
+                values + i * valueType::numlanes);
     }
 
     // Sort each loaded vector
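The `kvsort_n_vec` hunks on either side pass the key-vector mask through `resize_mask` because keys and values may sit in registers with different lane widths (e.g. 32-bit keys in a half-width register paired with 64-bit values in a full-width one), so a mask computed on the key vector has to be re-expressed for the value vector. On AVX2 that is a genuine vector operation; a sketch of the widening direction (an assumption for illustration, not the library's exact `resize_mask`):

```cpp
#include <immintrin.h>

// A 4-lane mask with 32-bit lanes becomes a 4-lane mask with 64-bit lanes:
// all-ones 32-bit lanes sign-extend to all-ones 64-bit lanes.
static __m256i widen_mask_32_to_64(__m128i mask32)
{
    return _mm256_cvtepi32_epi64(mask32);
}
```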
@@ -577,8 +619,9 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys,
     for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) {
         keyType::mask_storeu(
                 keys + i * keyType::numlanes, ioMasks[j], keyVecs[i]);
-        valueType::mask_storeu(
-                values + i * valueType::numlanes, ioMasks[j], valueVecs[i]);
+        valueType::mask_storeu(values + i * valueType::numlanes,
+                               resize_mask<keyType, valueType>(ioMasks[j]),
+                               valueVecs[i]);
     }
 }
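Finally, a hedged end-to-end check of the new AVX2 path against the scalar reference (test sketch under the assumption that `x86simdsort::keyvalue_qsort` is the public entry point and the program is linked against the library; not part of the patch):

```cpp
#include "x86simdsort.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

int main()
{
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    std::vector<float> keys(1000);
    std::vector<uint32_t> vals(keys.size());
    for (auto &k : keys)
        k = dist(rng);
    std::iota(vals.begin(), vals.end(), 0u); // value i tags original slot i
    const std::vector<float> orig = keys;

    x86simdsort::keyvalue_qsort(keys.data(), vals.data(), keys.size());

    assert(std::is_sorted(keys.begin(), keys.end()));
    for (size_t i = 0; i < keys.size(); i++)
        assert(keys[i] == orig[vals[i]]); // each value still points at its key
    return 0;
}
```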