From ae656cc855d63f758e249bcd3cc59cecfc6f845d Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 16 Apr 2024 10:18:25 -0700 Subject: [PATCH 1/5] Adds support for AVX2 key-value sorting --- lib/x86simdsort-avx2.cpp | 41 ++++- lib/x86simdsort-skx.cpp | 1 + lib/x86simdsort.cpp | 12 +- src/avx2-32bit-qsort.hpp | 9 ++ src/avx512-64bit-common.h | 104 ++++++++++++ ...uesort.hpp => xss-common-keyvaluesort.hpp} | 55 +++++-- src/xss-common-qsort.h | 3 + src/xss-network-keyvaluesort.hpp | 152 +++++++++++------- 8 files changed, 304 insertions(+), 73 deletions(-) rename src/{avx512-64bit-keyvaluesort.hpp => xss-common-keyvaluesort.hpp} (90%) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 2afc4d1d..7a220968 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -1,4 +1,5 @@ // AVX2 specific routines: + #include "x86simdsort-static-incl.h" #include "x86simdsort-internal.h" @@ -33,6 +34,38 @@ return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \ } +#define DEFINE_KEYVALUE_METHODS(type) \ + template <> \ + void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \ + { \ + avx2_qsort_kv(key, val, arrsize, hasnan); \ + } \ + template <> \ + void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \ + { \ + avx2_qsort_kv(key, val, arrsize, hasnan); \ + } \ + template <> \ + void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \ + { \ + avx2_qsort_kv(key, val, arrsize, hasnan); \ + } \ + template <> \ + void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \ + { \ + avx2_qsort_kv(key, val, arrsize, hasnan); \ + } \ + template <> \ + void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \ + { \ + avx2_qsort_kv(key, val, arrsize, hasnan); \ + } \ + template <> \ + void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \ + { \ + avx2_qsort_kv(key, val, arrsize, hasnan); \ + } + namespace xss { namespace avx2 { DEFINE_ALL_METHODS(uint32_t) @@ -41,5 +74,11 @@ namespace avx2 { DEFINE_ALL_METHODS(uint64_t) DEFINE_ALL_METHODS(int64_t) DEFINE_ALL_METHODS(double) + DEFINE_KEYVALUE_METHODS(uint64_t) + DEFINE_KEYVALUE_METHODS(int64_t) + DEFINE_KEYVALUE_METHODS(double) + DEFINE_KEYVALUE_METHODS(uint32_t) + DEFINE_KEYVALUE_METHODS(int32_t) + DEFINE_KEYVALUE_METHODS(float) } // namespace avx2 -} // namespace xss +} // namespace xss \ No newline at end of file diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 829dd7b8..e51c51ed 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -1,4 +1,5 @@ // SKX specific routines: + #include "x86simdsort-static-incl.h" #include "x86simdsort-internal.h" diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 0c16f148..2f268abc 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -208,12 +208,12 @@ DISPATCH_ALL(argselect, (ISA_LIST("avx512_skx", "avx2"))) #define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \ - DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx"))) \ - DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx"))) \ - DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx"))) \ - DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx"))) \ - DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx"))) \ - DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx"))) + DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx", "avx2"))) \ + DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx", "avx2"))) \ + DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx", "avx2"))) \ + DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx", "avx2"))) \ + DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx", "avx2"))) \ + DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx", "avx2"))) DISPATCH_KEYVALUE_SORT_FORTYPE(uint64_t) DISPATCH_KEYVALUE_SORT_FORTYPE(int64_t) diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp index b8cca7a4..a4c7e003 100644 --- a/src/avx2-32bit-qsort.hpp +++ b/src/avx2-32bit-qsort.hpp @@ -99,6 +99,9 @@ struct avx2_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask(mask); } + static opmask_t convert_int_to_mask(uint64_t intMask){ + return convert_int_to_avx2_mask(intMask); + } static ymmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { @@ -268,6 +271,9 @@ struct avx2_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask(mask); } + static opmask_t convert_int_to_mask(uint64_t intMask){ + return convert_int_to_avx2_mask(intMask); + } static ymmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { @@ -444,6 +450,9 @@ struct avx2_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask(mask); } + static opmask_t convert_int_to_mask(uint64_t intMask){ + return convert_int_to_avx2_mask(intMask); + } static int32_t convert_mask_to_int(opmask_t mask) { return convert_avx2_mask_to_int(mask); diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 8f5fdce9..e36a996e 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -24,6 +24,7 @@ template X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm); struct avx512_64bit_swizzle_ops; +struct avx512_ymm_64bit_swizzle_ops; template <> struct ymm_vector { @@ -33,6 +34,8 @@ struct ymm_vector { using opmask_t = __mmask8; static const uint8_t numlanes = 8; static constexpr simd_type vec_type = simd_type::AVX512; + + using swizzle_ops = avx512_ymm_64bit_swizzle_ops; static type_t type_max() { @@ -212,6 +215,14 @@ struct ymm_vector { const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); return permutexvar(rev_index, ymm); } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> struct ymm_vector { @@ -221,6 +232,8 @@ struct ymm_vector { using opmask_t = __mmask8; static const uint8_t numlanes = 8; static constexpr simd_type vec_type = simd_type::AVX512; + + using swizzle_ops = avx512_ymm_64bit_swizzle_ops; static type_t type_max() { @@ -386,6 +399,14 @@ struct ymm_vector { const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); return permutexvar(rev_index, ymm); } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> struct ymm_vector { @@ -395,6 +416,8 @@ struct ymm_vector { using opmask_t = __mmask8; static const uint8_t numlanes = 8; static constexpr simd_type vec_type = simd_type::AVX512; + + using swizzle_ops = avx512_ymm_64bit_swizzle_ops; static type_t type_max() { @@ -560,6 +583,14 @@ struct ymm_vector { const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); return permutexvar(rev_index, ymm); } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -1215,4 +1246,77 @@ struct avx512_64bit_swizzle_ops { } }; +struct avx512_ymm_64bit_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, 0b10110001); + v = _mm256_castps_si256(vf); + } + else if constexpr (scale == 4) { + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, 0b01001110); + v = _mm256_castps_si256(vf); + } + else if constexpr (scale == 8) { + v = _mm256_permute2x128_si256(v, v, 0b00000001); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { + constexpr uint64_t mask = 0b00011011; + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, mask); + v = _mm256_castps_si256(vf); + } + else if constexpr (scale == 8) { + return vtype::reverse(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { + __m256i v1 = vtype::cast_to(reg); + __m256i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { + v1 = _mm256_blend_epi32(v1, v2, 0b01010101); + } + else if constexpr (scale == 4) { + v1 = _mm256_blend_epi32(v1, v2, 0b00110011); + } + else if constexpr (scale == 8) { + v1 = _mm256_blend_epi32(v1, v2, 0b00001111); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; + #endif diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp similarity index 90% rename from src/avx512-64bit-keyvaluesort.hpp rename to src/xss-common-keyvaluesort.hpp index 9c4f1459..8e935e53 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -8,7 +8,8 @@ #ifndef AVX512_QSORT_64BIT_KV #define AVX512_QSORT_64BIT_KV -#include "avx512-64bit-common.h" + +#include "xss-common-qsort.h" #include "xss-network-keyvaluesort.hpp" /* @@ -33,15 +34,10 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t1 *keys, { /* which elements are larger than the pivot */ typename vtype1::opmask_t gt_mask = vtype1::ge(keys_vec, pivot_vec); - int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); - vtype1::mask_compressstoreu( - keys + left, vtype1::knot_opmask(gt_mask), keys_vec); - vtype1::mask_compressstoreu( - keys + right - amount_gt_pivot, gt_mask, keys_vec); - vtype2::mask_compressstoreu( - indexes + left, vtype2::knot_opmask(gt_mask), indexes_vec); - vtype2::mask_compressstoreu( - indexes + right - amount_gt_pivot, gt_mask, indexes_vec); + + int32_t amount_gt_pivot = vtype1::double_compressstore(keys + left, keys + right - vtype1::numlanes, gt_mask, keys_vec); + vtype2::double_compressstore(indexes + left, indexes + right - vtype2::numlanes, resize_mask(gt_mask), indexes_vec); + *smallest_vec = vtype1::min(keys_vec, *smallest_vec); *biggest_vec = vtype1::max(keys_vec, *biggest_vec); return amount_gt_pivot; @@ -441,4 +437,43 @@ avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) replace_inf_with_nan(keys, arrsize, nan_count); } } + +template +X86_SIMD_SORT_INLINE void +avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) +{ + using keytype = + typename std::conditional, + avx2_vector>::type; + using valtype = + typename std::conditional, + avx2_vector>::type; + + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + arrsize_t nan_count = 0; + if (UNLIKELY(hasnan)) { + nan_count = replace_nan_with_inf>(keys, arrsize); + } + qsort_64bit_(keys, + indexes, + 0, + arrsize - 1, + 2 * (arrsize_t)log2(arrsize)); + replace_inf_with_nan(keys, arrsize, nan_count); + } + else { + UNUSED(hasnan); + qsort_64bit_(keys, + indexes, + 0, + arrsize - 1, + 2 * (arrsize_t)log2(arrsize)); + } + } +} #endif // AVX512_QSORT_64BIT_KV diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 2d5b4ea1..0ba04f07 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -189,6 +189,9 @@ X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b) b = vtype::max(temp, b); } +template +typename vtype::opmask_t convert_int_to_mask(maskType mask); + template diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index bb9b9bcd..afb68a0c 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -1,6 +1,19 @@ #ifndef XSS_KEYVALUE_NETWORKS #define XSS_KEYVALUE_NETWORKS +#include "xss-common-includes.h" + +template +typename vtype::opmask_t convert_int_to_mask(maskType mask){ + if constexpr (vtype::vec_type == simd_type::AVX512) { + return mask; + }else if constexpr (vtype::vec_type == simd_type::AVX2){ + return vtype::convert_int_to_mask(mask); + }else{ + static_assert(always_false, "Error in func convert_int_to_mask"); + } +} + template typename valueType::opmask_t resize_mask(typename keyType::opmask_t mask) { @@ -70,66 +83,74 @@ template (0xAAAA); + const auto oxCCCC = convert_int_to_mask(0xCCCC); + const auto oxFF00 = convert_int_to_mask(0xFF00); + const auto oxF0F0 = convert_int_to_mask(0xF0F0); + key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template reverse_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xAAAA); + index_swizzle::template reverse_n(index_zmm), + oxAAAA); key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template reverse_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xCCCC); + index_swizzle::template reverse_n(index_zmm), + oxCCCC); key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template reverse_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xAAAA); + index_swizzle::template reverse_n(index_zmm), + oxAAAA); key_zmm = cmp_merge( key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_3), key_zmm), index_zmm, vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_3), index_zmm), - 0xF0F0); + oxF0F0); key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template swap_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xCCCC); + index_swizzle::template swap_n(index_zmm), + oxCCCC); key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template reverse_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xAAAA); + index_swizzle::template reverse_n(index_zmm), + oxAAAA); key_zmm = cmp_merge( key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_5), key_zmm), index_zmm, vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_5), index_zmm), - 0xFF00); + oxFF00); key_zmm = cmp_merge( key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_6), key_zmm), index_zmm, vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_6), index_zmm), - 0xF0F0); + oxF0F0); key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template swap_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xCCCC); + index_swizzle::template swap_n(index_zmm), + oxCCCC); key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template reverse_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xAAAA); + index_swizzle::template reverse_n(index_zmm), + oxAAAA); return key_zmm; } @@ -141,30 +162,38 @@ template (0xAAAA); + const auto oxCCCC = convert_int_to_mask(0xCCCC); + const auto oxFF00 = convert_int_to_mask(0xFF00); + const auto oxF0F0 = convert_int_to_mask(0xF0F0); + key_zmm = cmp_merge( key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_7), key_zmm), index_zmm, vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_7), index_zmm), - 0xFF00); + oxFF00); key_zmm = cmp_merge( key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_6), key_zmm), index_zmm, vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_6), index_zmm), - 0xF0F0); + oxF0F0); key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template swap_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xCCCC); + index_swizzle::template swap_n(index_zmm), + oxCCCC); key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template reverse_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xAAAA); + index_swizzle::template reverse_n(index_zmm), + oxAAAA); return key_zmm; } @@ -173,45 +202,50 @@ template X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes(reg_t key_zmm, index_type &index_zmm) -{ - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); +{ + using key_swizzle = typename vtype1::swizzle_ops; + using index_swizzle = typename vtype2::swizzle_ops; + + const auto oxAA = convert_int_to_mask(0xAA); + const auto oxCC = convert_int_to_mask(0xCC); + const auto oxF0 = convert_int_to_mask(0xF0); + key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template swap_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xAA); + index_swizzle::template swap_n(index_zmm), + oxAA); key_zmm = cmp_merge( key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_1), key_zmm), index_zmm, vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_1), index_zmm), - 0xCC); + oxCC); key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template swap_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xAA); + index_swizzle::template swap_n(index_zmm), + oxAA); key_zmm = cmp_merge( key_zmm, - vtype1::permutexvar(rev_index1, key_zmm), + vtype1::reverse(key_zmm), index_zmm, - vtype2::permutexvar(rev_index2, index_zmm), - 0xF0); + vtype2::reverse(index_zmm), + oxF0); key_zmm = cmp_merge( key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_3), key_zmm), index_zmm, vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_3), index_zmm), - 0xCC); + oxCC); key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template swap_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xAA); + index_swizzle::template swap_n(index_zmm), + oxAA); return key_zmm; } @@ -255,6 +289,12 @@ template (0xAA); + const auto oxCC = convert_int_to_mask(0xCC); + const auto oxF0 = convert_int_to_mask(0xF0); // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 key_zmm = cmp_merge( @@ -262,21 +302,21 @@ X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_8lanes(reg_t key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_4), key_zmm), index_zmm, vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_4), index_zmm), - 0xF0); + oxF0); // 2) half_cleaner[4] key_zmm = cmp_merge( key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_3), key_zmm), index_zmm, vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_3), index_zmm), - 0xCC); + oxCC); // 3) half_cleaner[1] key_zmm = cmp_merge( key_zmm, - vtype1::template shuffle(key_zmm), + key_swizzle::template swap_n(key_zmm), index_zmm, - vtype2::template shuffle(index_zmm), - 0xAA); + index_swizzle::template swap_n(index_zmm), + oxAA); return key_zmm; } @@ -553,7 +593,7 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, keyVecs[i] = keyType::mask_loadu( keyType::zmm_max(), ioMasks[j], keys + i * keyType::numlanes); valueVecs[i] = valueType::mask_loadu(valueType::zmm_max(), - ioMasks[j], + resize_mask(ioMasks[j]), values + i * valueType::numlanes); } @@ -578,7 +618,7 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, keyType::mask_storeu( keys + i * keyType::numlanes, ioMasks[j], keyVecs[i]); valueType::mask_storeu( - values + i * valueType::numlanes, ioMasks[j], valueVecs[i]); + values + i * valueType::numlanes, resize_mask(ioMasks[j]), valueVecs[i]); } } From 5916cfe12a68591b0def6a395560c201213d6e4f Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 19 Apr 2024 13:04:38 -0700 Subject: [PATCH 2/5] clang-format --- src/avx2-32bit-qsort.hpp | 9 ++++-- src/xss-common-keyvaluesort.hpp | 15 ++++++---- src/xss-network-keyvaluesort.hpp | 51 +++++++++++++++++--------------- 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp index a4c7e003..56ffc232 100644 --- a/src/avx2-32bit-qsort.hpp +++ b/src/avx2-32bit-qsort.hpp @@ -99,7 +99,8 @@ struct avx2_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask(mask); } - static opmask_t convert_int_to_mask(uint64_t intMask){ + static opmask_t convert_int_to_mask(uint64_t intMask) + { return convert_int_to_avx2_mask(intMask); } static ymmi_t @@ -271,7 +272,8 @@ struct avx2_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask(mask); } - static opmask_t convert_int_to_mask(uint64_t intMask){ + static opmask_t convert_int_to_mask(uint64_t intMask) + { return convert_int_to_avx2_mask(intMask); } static ymmi_t @@ -450,7 +452,8 @@ struct avx2_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask(mask); } - static opmask_t convert_int_to_mask(uint64_t intMask){ + static opmask_t convert_int_to_mask(uint64_t intMask) + { return convert_int_to_avx2_mask(intMask); } static int32_t convert_mask_to_int(opmask_t mask) diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 8e935e53..8b2dc8e2 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -34,10 +34,14 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t1 *keys, { /* which elements are larger than the pivot */ typename vtype1::opmask_t gt_mask = vtype1::ge(keys_vec, pivot_vec); - - int32_t amount_gt_pivot = vtype1::double_compressstore(keys + left, keys + right - vtype1::numlanes, gt_mask, keys_vec); - vtype2::double_compressstore(indexes + left, indexes + right - vtype2::numlanes, resize_mask(gt_mask), indexes_vec); - + + int32_t amount_gt_pivot = vtype1::double_compressstore( + keys + left, keys + right - vtype1::numlanes, gt_mask, keys_vec); + vtype2::double_compressstore(indexes + left, + indexes + right - vtype2::numlanes, + resize_mask(gt_mask), + indexes_vec); + *smallest_vec = vtype1::min(keys_vec, *smallest_vec); *biggest_vec = vtype1::max(keys_vec, *biggest_vec); return amount_gt_pivot; @@ -457,7 +461,8 @@ avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) if constexpr (std::is_floating_point_v) { arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { - nan_count = replace_nan_with_inf>(keys, arrsize); + nan_count + = replace_nan_with_inf>(keys, arrsize); } qsort_64bit_(keys, indexes, diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index afb68a0c..3cb30378 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -4,13 +4,15 @@ #include "xss-common-includes.h" template -typename vtype::opmask_t convert_int_to_mask(maskType mask){ - if constexpr (vtype::vec_type == simd_type::AVX512) { - return mask; - }else if constexpr (vtype::vec_type == simd_type::AVX2){ +typename vtype::opmask_t convert_int_to_mask(maskType mask) +{ + if constexpr (vtype::vec_type == simd_type::AVX512) { return mask; } + else if constexpr (vtype::vec_type == simd_type::AVX2) { return vtype::convert_int_to_mask(mask); - }else{ - static_assert(always_false, "Error in func convert_int_to_mask"); + } + else { + static_assert(always_false, + "Error in func convert_int_to_mask"); } } @@ -85,12 +87,12 @@ X86_SIMD_SORT_INLINE reg_t sort_reg_16lanes(reg_t key_zmm, { using key_swizzle = typename vtype1::swizzle_ops; using index_swizzle = typename vtype2::swizzle_ops; - + const auto oxAAAA = convert_int_to_mask(0xAAAA); const auto oxCCCC = convert_int_to_mask(0xCCCC); const auto oxFF00 = convert_int_to_mask(0xFF00); const auto oxF0F0 = convert_int_to_mask(0xF0F0); - + key_zmm = cmp_merge( key_zmm, key_swizzle::template reverse_n(key_zmm), @@ -164,12 +166,12 @@ X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_16lanes(reg_t key_zmm, { using key_swizzle = typename vtype1::swizzle_ops; using index_swizzle = typename vtype2::swizzle_ops; - + const auto oxAAAA = convert_int_to_mask(0xAAAA); const auto oxCCCC = convert_int_to_mask(0xCCCC); const auto oxFF00 = convert_int_to_mask(0xFF00); const auto oxF0F0 = convert_int_to_mask(0xF0F0); - + key_zmm = cmp_merge( key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_7), key_zmm), @@ -202,14 +204,14 @@ template X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes(reg_t key_zmm, index_type &index_zmm) -{ +{ using key_swizzle = typename vtype1::swizzle_ops; using index_swizzle = typename vtype2::swizzle_ops; const auto oxAA = convert_int_to_mask(0xAA); const auto oxCC = convert_int_to_mask(0xCC); const auto oxF0 = convert_int_to_mask(0xF0); - + key_zmm = cmp_merge( key_zmm, key_swizzle::template swap_n(key_zmm), @@ -228,12 +230,11 @@ X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes(reg_t key_zmm, index_type &index_zmm) index_zmm, index_swizzle::template swap_n(index_zmm), oxAA); - key_zmm = cmp_merge( - key_zmm, - vtype1::reverse(key_zmm), - index_zmm, - vtype2::reverse(index_zmm), - oxF0); + key_zmm = cmp_merge(key_zmm, + vtype1::reverse(key_zmm), + index_zmm, + vtype2::reverse(index_zmm), + oxF0); key_zmm = cmp_merge( key_zmm, vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_3), key_zmm), @@ -291,7 +292,7 @@ X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_8lanes(reg_t key_zmm, { using key_swizzle = typename vtype1::swizzle_ops; using index_swizzle = typename vtype2::swizzle_ops; - + const auto oxAA = convert_int_to_mask(0xAA); const auto oxCC = convert_int_to_mask(0xCC); const auto oxF0 = convert_int_to_mask(0xF0); @@ -592,9 +593,10 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { keyVecs[i] = keyType::mask_loadu( keyType::zmm_max(), ioMasks[j], keys + i * keyType::numlanes); - valueVecs[i] = valueType::mask_loadu(valueType::zmm_max(), - resize_mask(ioMasks[j]), - values + i * valueType::numlanes); + valueVecs[i] = valueType::mask_loadu( + valueType::zmm_max(), + resize_mask(ioMasks[j]), + values + i * valueType::numlanes); } // Sort each loaded vector @@ -617,8 +619,9 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { keyType::mask_storeu( keys + i * keyType::numlanes, ioMasks[j], keyVecs[i]); - valueType::mask_storeu( - values + i * valueType::numlanes, resize_mask(ioMasks[j]), valueVecs[i]); + valueType::mask_storeu(values + i * valueType::numlanes, + resize_mask(ioMasks[j]), + valueVecs[i]); } } From 7b6681679ac5d7023ec070ab776eb9230803bbfa Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 19 Apr 2024 16:38:32 -0700 Subject: [PATCH 3/5] Cleaning up some code and de-duplicating some logic --- src/avx512-64bit-common.h | 6 +- src/xss-common-keyvaluesort.hpp | 99 +++++++++++++++------------------ src/xss-common-qsort.h | 3 - 3 files changed, 47 insertions(+), 61 deletions(-) diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index e36a996e..6e2b6eff 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -34,7 +34,7 @@ struct ymm_vector { using opmask_t = __mmask8; static const uint8_t numlanes = 8; static constexpr simd_type vec_type = simd_type::AVX512; - + using swizzle_ops = avx512_ymm_64bit_swizzle_ops; static type_t type_max() @@ -232,7 +232,7 @@ struct ymm_vector { using opmask_t = __mmask8; static const uint8_t numlanes = 8; static constexpr simd_type vec_type = simd_type::AVX512; - + using swizzle_ops = avx512_ymm_64bit_swizzle_ops; static type_t type_max() @@ -416,7 +416,7 @@ struct ymm_vector { using opmask_t = __mmask8; static const uint8_t numlanes = 8; static constexpr simd_type vec_type = simd_type::AVX512; - + using swizzle_ops = avx512_ymm_64bit_swizzle_ops; static type_t type_max() diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 8b2dc8e2..30403ff8 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -363,11 +363,11 @@ template -X86_SIMD_SORT_INLINE void qsort_64bit_(type1_t *keys, - type2_t *indexes, - arrsize_t left, - arrsize_t right, - int max_iters) +X86_SIMD_SORT_INLINE void kvsort_(type1_t *keys, + type2_t *indexes, + arrsize_t left, + arrsize_t right, + int max_iters) { /* * Resort to std::sort if quicksort isnt making any progress @@ -393,32 +393,35 @@ X86_SIMD_SORT_INLINE void qsort_64bit_(type1_t *keys, arrsize_t pivot_index = kvpartition_unrolled( keys, indexes, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) { - qsort_64bit_( + kvsort_( keys, indexes, left, pivot_index - 1, max_iters - 1); } if (pivot != biggest) { - qsort_64bit_( + kvsort_( keys, indexes, pivot_index, right, max_iters - 1); } } -template +template + typename full_vector, + template + typename half_vector> X86_SIMD_SORT_INLINE void -avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) +xss_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan) { using keytype = typename std::conditional, - zmm_vector>::type; + half_vector, + full_vector>::type; using valtype = typename std::conditional, - zmm_vector>::type; -/* - * Enable testing the heapsort key-value sort in the CI: - */ + half_vector, + full_vector>::type; + #ifdef XSS_TEST_KEYVALUE_BASE_CASE int maxiters = -1; bool minarrsize = true; @@ -428,57 +431,43 @@ avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) #endif // XSS_TEST_KEYVALUE_BASE_CASE if (minarrsize) { - arrsize_t nan_count = 0; - if constexpr (xss::fp::is_floating_point_v) { + if constexpr (std::is_floating_point_v) { + arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { - nan_count = replace_nan_with_inf>(keys, arrsize); + nan_count + = replace_nan_with_inf>(keys, arrsize); } + kvsort_(keys, + indexes, + 0, + arrsize - 1, + 2 * (arrsize_t)log2(arrsize)); + replace_inf_with_nan(keys, arrsize, nan_count); } else { UNUSED(hasnan); + kvsort_(keys, + indexes, + 0, + arrsize - 1, + 2 * (arrsize_t)log2(arrsize)); } - qsort_64bit_(keys, indexes, 0, arrsize - 1, maxiters); - replace_inf_with_nan(keys, arrsize, nan_count); } } template X86_SIMD_SORT_INLINE void -avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) +avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) { - using keytype = - typename std::conditional, - avx2_vector>::type; - using valtype = - typename std::conditional, - avx2_vector>::type; + xss_qsort_kv( + keys, indexes, arrsize, hasnan); +} - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - arrsize_t nan_count = 0; - if (UNLIKELY(hasnan)) { - nan_count - = replace_nan_with_inf>(keys, arrsize); - } - qsort_64bit_(keys, - indexes, - 0, - arrsize - 1, - 2 * (arrsize_t)log2(arrsize)); - replace_inf_with_nan(keys, arrsize, nan_count); - } - else { - UNUSED(hasnan); - qsort_64bit_(keys, - indexes, - 0, - arrsize - 1, - 2 * (arrsize_t)log2(arrsize)); - } - } +template +X86_SIMD_SORT_INLINE void +avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) +{ + xss_qsort_kv( + keys, indexes, arrsize, hasnan); } #endif // AVX512_QSORT_64BIT_KV diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 0ba04f07..2d5b4ea1 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -189,9 +189,6 @@ X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b) b = vtype::max(temp, b); } -template -typename vtype::opmask_t convert_int_to_mask(maskType mask); - template From c55ab7a4ba35bfb49de142b0bbecfea63572783e Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 23 Apr 2024 11:55:40 -0700 Subject: [PATCH 4/5] Rebased onto recent changes --- lib/x86simdsort-avx2.cpp | 12 ++++++------ src/x86simdsort-static-incl.h | 16 +++++++--------- src/xss-common-keyvaluesort.hpp | 13 ++----------- 3 files changed, 15 insertions(+), 26 deletions(-) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 7a220968..4c1123e4 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -38,32 +38,32 @@ template <> \ void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \ { \ - avx2_qsort_kv(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ } \ template <> \ void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \ { \ - avx2_qsort_kv(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ } \ template <> \ void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \ { \ - avx2_qsort_kv(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ } \ template <> \ void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \ { \ - avx2_qsort_kv(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ } \ template <> \ void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \ { \ - avx2_qsort_kv(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ } \ template <> \ void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \ { \ - avx2_qsort_kv(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ } namespace xss { diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 0d0e4400..1f849004 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -100,6 +100,12 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false); std::iota(indices.begin(), indices.end(), 0); \ x86simdsortStatic::argselect(arr, indices.data(), k, size, hasnan); \ return indices; \ + } \ + template \ + X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_qsort( \ + T1 *key, T2 *val, size_t size, bool hasnan) \ + { \ + ISA##_qsort_kv(key, val, size, hasnan); \ } /* @@ -107,13 +113,13 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false); */ #include "xss-common-qsort.h" #include "xss-common-argsort.h" +#include "xss-common-keyvaluesort.hpp" #if defined(__AVX512DQ__) && defined(__AVX512VL__) /* 32-bit and 64-bit dtypes vector definitions on SKX */ #include "avx512-32bit-qsort.hpp" #include "avx512-64bit-qsort.hpp" #include "avx512-64bit-argsort.hpp" -#include "avx512-64bit-keyvaluesort.hpp" /* 16-bit dtypes vector definitions on ICL */ #if defined(__AVX512BW__) && defined(__AVX512VBMI2__) @@ -126,14 +132,6 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false); XSS_METHODS(avx512) -// key-value currently only on avx512 -template -X86_SIMD_SORT_FINLINE void -x86simdsortStatic::keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan) -{ - avx512_qsort_kv(key, val, size, hasnan); -} - #elif defined(__AVX512F__) #error "x86simdsort requires AVX512DQ and AVX512VL to be enabled in addition to AVX512F to use AVX512" diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 30403ff8..de62d317 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -8,7 +8,6 @@ #ifndef AVX512_QSORT_64BIT_KV #define AVX512_QSORT_64BIT_KV - #include "xss-common-qsort.h" #include "xss-network-keyvaluesort.hpp" @@ -437,20 +436,12 @@ xss_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan) nan_count = replace_nan_with_inf>(keys, arrsize); } - kvsort_(keys, - indexes, - 0, - arrsize - 1, - 2 * (arrsize_t)log2(arrsize)); + kvsort_(keys, indexes, 0, arrsize - 1, maxiters); replace_inf_with_nan(keys, arrsize, nan_count); } else { UNUSED(hasnan); - kvsort_(keys, - indexes, - 0, - arrsize - 1, - 2 * (arrsize_t)log2(arrsize)); + kvsort_(keys, indexes, 0, arrsize - 1, maxiters); } } } From 36077526b8345eb4879ca895238a3820fbc11cdd Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 6 May 2024 12:42:00 -0700 Subject: [PATCH 5/5] Minor code format --- src/xss-common-keyvaluesort.hpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index de62d317..4699b8a1 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -430,19 +430,18 @@ xss_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan) #endif // XSS_TEST_KEYVALUE_BASE_CASE if (minarrsize) { - if constexpr (std::is_floating_point_v) { - arrsize_t nan_count = 0; + arrsize_t nan_count = 0; + if constexpr (xss::fp::is_floating_point_v) { if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf>(keys, arrsize); } - kvsort_(keys, indexes, 0, arrsize - 1, maxiters); - replace_inf_with_nan(keys, arrsize, nan_count); } else { UNUSED(hasnan); - kvsort_(keys, indexes, 0, arrsize - 1, maxiters); } + kvsort_(keys, indexes, 0, arrsize - 1, maxiters); + replace_inf_with_nan(keys, arrsize, nan_count); } }