diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index 0c819946..dbddd7b4 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -316,10 +316,53 @@ static void qselect_16bit_(type_t *arr, type_t biggest = vtype::type_min(); int64_t pivot_index = partition_avx512( arr, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - qselect_16bit_(arr, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - qselect_16bit_(arr, pos, pivot_index, right, max_iters - 1); + + if (pos < pivot_index) { + if (pivot != smallest) + qselect_16bit_(arr, pos, left, pivot_index - 1, max_iters - 1); + } else { + if (pivot != biggest) + qselect_16bit_(arr, pos, pivot_index, right, max_iters - 1); + } +} + +template +static void qselsort_16bit_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std::sort(arr + left, arr + right + 1, comparison_func); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 128 + */ + if (right + 1 - left <= 128) { + sort_128_16bit(arr + left, (int32_t)(right + 1 - left)); + return; + } + + type_t pivot = get_pivot_16bit(arr, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512( + arr, left, right + 1, pivot, &smallest, &biggest); + + if (pos < pivot_index) { + if (pivot != smallest) + qselsort_16bit_(arr, pos, left, pivot_index - 1, max_iters - 1); + } else { + if (pivot != smallest) + qsort_16bit_(arr, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) + qselsort_16bit_(arr, pos, pivot_index, right, max_iters - 1); + } } #endif // AVX512_16BIT_COMMON diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 1efcf1e9..68fd4017 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -433,7 +433,7 @@ void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_16bit_, uint16_t>( @@ -441,6 +441,36 @@ void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) } } +template <> +void avx512_partial_qsort(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan) +{ + if (arrsize > 1) { + qselsort_16bit_, int16_t>( + arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_partial_qsort(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) +{ + if (arrsize > 1) { + qselsort_16bit_, uint16_t>( + arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) +{ + if (LIKELY(k > 0)) { + int64_t indx_last_elem = arrsize - 1; + if (UNLIKELY(hasnan)) { + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + } + qselsort_16bit_, uint16_t>( + arr, k - 1, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + } +} + template <> void avx512_qsort(int16_t *arr, int64_t arrsize) { diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index bfd4a151..0a271367 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -683,10 +683,52 @@ static void qselect_32bit_(type_t *arr, type_t biggest = vtype::type_min(); int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - qselect_32bit_(arr, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - qselect_32bit_(arr, pos, pivot_index, right, max_iters - 1); + if (pos < pivot_index) { + if (pivot != smallest) + qselect_32bit_(arr, pos, left, pivot_index - 1, max_iters - 1); + } else { + if (pivot != biggest) + qselect_32bit_(arr, pos, pivot_index, right, max_iters - 1); + } +} + +template +static void qselsort_32bit_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std::sort(arr + left, arr + right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 128 + */ + if (right + 1 - left <= 128) { + sort_128_32bit(arr + left, (int32_t)(right + 1 - left)); + return; + } + + type_t pivot = get_pivot_32bit(arr, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512_unrolled( + arr, left, right + 1, pivot, &smallest, &biggest); + + if (pos < pivot_index) { + if (pivot != smallest) + qselsort_32bit_(arr, pos, left, pivot_index - 1, max_iters - 1); + } else { + if (pivot != smallest) + qsort_32bit_(arr, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) + qselsort_32bit_(arr, pos, pivot_index, right, max_iters - 1); + } } X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize) @@ -737,7 +779,7 @@ void avx512_qselect(float *arr, int64_t k, int64_t arrsize, bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_32bit_, float>( @@ -745,6 +787,37 @@ void avx512_qselect(float *arr, int64_t k, int64_t arrsize, bool hasnan) } } +template <> +void avx512_partial_qsort(int32_t *arr, int64_t k, int64_t arrsize, bool hasnan) +{ + if (arrsize > 1) { + qselsort_32bit_, int32_t>( + arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_partial_qsort(uint32_t *arr, int64_t k, int64_t arrsize, bool hasnan) +{ + if (arrsize > 1) { + qselsort_32bit_, uint32_t>( + arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_partial_qsort(float *arr, int64_t k, int64_t arrsize, bool hasnan) +{ + if (LIKELY(k > 0)) { + int64_t indx_last_elem = arrsize - 1; + if (UNLIKELY(hasnan)) { + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + } + qselsort_32bit_, float>( + arr, k - 1, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + } +} + template <> void avx512_qsort(int32_t *arr, int64_t arrsize) { diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index aa5d7958..28f9b81d 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -732,7 +732,7 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) return; } /* - * Base case: use bitonic networks to sort arrays <= 128 + * Base case: use bitonic networks to sort arrays <= 256 */ if (right + 1 - left <= 256) { sort_256_64bit(arr + left, (int32_t)(right + 1 - left)); @@ -777,10 +777,53 @@ static void qselect_64bit_(type_t *arr, type_t biggest = vtype::type_min(); int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - qselect_64bit_(arr, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - qselect_64bit_(arr, pos, pivot_index, right, max_iters - 1); + + if (pos < pivot_index) { + if (pivot != smallest) + qselect_64bit_(arr, pos, left, pivot_index - 1, max_iters - 1); + } else { + if (pivot != biggest) + qselect_64bit_(arr, pos, pivot_index, right, max_iters - 1); + } +} + +template +static void qselsort_64bit_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std::sort(arr + left, arr + right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 128 + */ + if (right + 1 - left <= 128) { + sort_128_64bit(arr + left, (int32_t)(right + 1 - left)); + return; + } + + type_t pivot = get_pivot_64bit(arr, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512_unrolled( + arr, left, right + 1, pivot, &smallest, &biggest); + + if (pos < pivot_index) { + if (pivot != smallest) + qselsort_64bit_(arr, pos, left, pivot_index - 1, max_iters - 1); + } else { + if (pivot != smallest) + qsort_64bit_(arr, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) + qselsort_64bit_(arr, pos, pivot_index, right, max_iters - 1); + } } template <> @@ -806,7 +849,7 @@ void avx512_qselect(double *arr, int64_t k, int64_t arrsize, bool hasnan { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_64bit_, double>( @@ -814,6 +857,37 @@ void avx512_qselect(double *arr, int64_t k, int64_t arrsize, bool hasnan } } +template <> +void avx512_partial_qsort(int64_t *arr, int64_t k, int64_t arrsize, bool hasnan) +{ + if (arrsize > 1) { + qselsort_64bit_, int64_t>( + arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_partial_qsort(uint64_t *arr, int64_t k, int64_t arrsize, bool hasnan) +{ + if (arrsize > 1) { + qselsort_64bit_, uint64_t>( + arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_partial_qsort(double *arr, int64_t k, int64_t arrsize, bool hasnan) +{ + if (LIKELY(k > 0)) { + int64_t indx_last_elem = arrsize - 1; + if (UNLIKELY(hasnan)) { + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + } + qselsort_64bit_, double>( + arr, k - 1, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + } +} + template <> void avx512_qsort(int64_t *arr, int64_t arrsize) { diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 841b4a83..adf11a95 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -104,16 +104,8 @@ void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false); void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false); template -inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false) -{ - avx512_qselect(arr, k - 1, arrsize, hasnan); - avx512_qsort(arr, k - 1); -} -inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false) -{ - avx512_qselect_fp16(arr, k - 1, arrsize, hasnan); - avx512_qsort_fp16(arr, k - 1); -} +void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false); +void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false); // key-value sort routines template diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 8e87ac29..9582afd9 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -157,7 +157,7 @@ void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_16bit_, _Float16>( @@ -165,6 +165,19 @@ void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan) } } +template <> +void avx512_partial_qsort(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan) +{ + if (LIKELY(k > 0)) { + int64_t indx_last_elem = arrsize - 1; + if (UNLIKELY(hasnan)) { + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + } + qselsort_16bit_, _Float16>( + arr, k - 1, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + } +} + template <> void avx512_qsort(_Float16 *arr, int64_t arrsize) {