Skip to content

Commit 28abc6c

Browse files
committed
Optimise partial sort to use a single pass
The original partial sort function was a combination of a QuickSelect pass followed by a QuickSort pass. This can lead to unnecessary comparisons and hence wasted CPU cycles. This update changes the internal method used by both `avx512_qselect` and `avx512_partial_qsort` so that they can both leverage the same logic. This is achieved by the addition of a new parameter that indicates whether the function should recurse into (and sort) the left partition (i.e. partially sort the front of the array) or simply recurse into the partition containing the desired element (i.e. select/position a single element).
1 parent 4043ecf commit 28abc6c

File tree

6 files changed

+71
-60
lines changed

6 files changed

+71
-60
lines changed

src/avx512-16bit-common.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,12 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
290290
}
291291

292292
template <typename vtype, typename type_t>
293-
static void qselect_16bit_(type_t *arr,
294-
int64_t pos,
295-
int64_t left,
296-
int64_t right,
297-
int64_t max_iters)
293+
static void qselsort_16bit_(type_t *arr,
294+
int64_t pos,
295+
int64_t left,
296+
int64_t right,
297+
int64_t max_iters,
298+
bool sortleft)
298299
{
299300
/*
300301
* Resort to std::sort if quicksort isnt making any progress
@@ -316,10 +317,10 @@ static void qselect_16bit_(type_t *arr,
316317
type_t biggest = vtype::type_min();
317318
int64_t pivot_index = partition_avx512<vtype>(
318319
arr, left, right + 1, pivot, &smallest, &biggest);
319-
if ((pivot != smallest) && (pos < pivot_index))
320-
qselect_16bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
320+
if ((pivot != smallest) && ((pos < pivot_index) || sortleft))
321+
qselsort_16bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1, sortleft);
321322
else if ((pivot != biggest) && (pos >= pivot_index))
322-
qselect_16bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
323+
qselsort_16bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1, sortleft);
323324
}
324325

325326
#endif // AVX512_16BIT_COMMON

src/avx512-16bit-qsort.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -412,32 +412,32 @@ bool is_a_nan<uint16_t>(uint16_t elem)
412412
}
413413

414414
template <>
415-
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
415+
void avx512_qselsort(int16_t *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan)
416416
{
417417
if (arrsize > 1) {
418-
qselect_16bit_<zmm_vector<int16_t>, int16_t>(
419-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
418+
qselsort_16bit_<zmm_vector<int16_t>, int16_t>(
419+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize), sortleft);
420420
}
421421
}
422422

423423
template <>
424-
void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
424+
void avx512_qselsort(uint16_t *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan)
425425
{
426426
if (arrsize > 1) {
427-
qselect_16bit_<zmm_vector<uint16_t>, uint16_t>(
428-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
427+
qselsort_16bit_<zmm_vector<uint16_t>, uint16_t>(
428+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize), sortleft);
429429
}
430430
}
431431

432-
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
432+
void avx512_qselsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan)
433433
{
434434
int64_t indx_last_elem = arrsize - 1;
435435
if (UNLIKELY(hasnan)) {
436436
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
437437
}
438438
if (indx_last_elem >= k) {
439-
qselect_16bit_<zmm_vector<float16>, uint16_t>(
440-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
439+
qselsort_16bit_<zmm_vector<float16>, uint16_t>(
440+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem), sortleft);
441441
}
442442
}
443443

src/avx512-32bit-qsort.hpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -657,11 +657,12 @@ qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
657657
}
658658

659659
template <typename vtype, typename type_t>
660-
static void qselect_32bit_(type_t *arr,
661-
int64_t pos,
662-
int64_t left,
663-
int64_t right,
664-
int64_t max_iters)
660+
static void qselsort_32bit_(type_t *arr,
661+
int64_t pos,
662+
int64_t left,
663+
int64_t right,
664+
int64_t max_iters,
665+
bool sortleft)
665666
{
666667
/*
667668
* Resort to std::sort if quicksort isnt making any progress
@@ -683,10 +684,10 @@ static void qselect_32bit_(type_t *arr,
683684
type_t biggest = vtype::type_min();
684685
int64_t pivot_index = partition_avx512_unrolled<vtype, 2>(
685686
arr, left, right + 1, pivot, &smallest, &biggest);
686-
if ((pivot != smallest) && (pos < pivot_index))
687-
qselect_32bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
687+
if ((pivot != smallest) && ((pos < pivot_index) || sortleft))
688+
qselsort_32bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1, sortleft);
688689
else if ((pivot != biggest) && (pos >= pivot_index))
689-
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
690+
qselsort_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1, sortleft);
690691
}
691692

692693
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
@@ -715,33 +716,33 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
715716
}
716717

717718
template <>
718-
void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
719+
void avx512_qselsort<int32_t>(int32_t *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan)
719720
{
720721
if (arrsize > 1) {
721-
qselect_32bit_<zmm_vector<int32_t>, int32_t>(
722-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
722+
qselsort_32bit_<zmm_vector<int32_t>, int32_t>(
723+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize), sortleft);
723724
}
724725
}
725726

726727
template <>
727-
void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
728+
void avx512_qselsort<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan)
728729
{
729730
if (arrsize > 1) {
730-
qselect_32bit_<zmm_vector<uint32_t>, uint32_t>(
731-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
731+
qselsort_32bit_<zmm_vector<uint32_t>, uint32_t>(
732+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize), sortleft);
732733
}
733734
}
734735

735736
template <>
736-
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
737+
void avx512_qselsort<float>(float *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan)
737738
{
738739
int64_t indx_last_elem = arrsize - 1;
739740
if (UNLIKELY(hasnan)) {
740741
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
741742
}
742743
if (indx_last_elem >= k) {
743-
qselect_32bit_<zmm_vector<float>, float>(
744-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
744+
qselsort_32bit_<zmm_vector<float>, float>(
745+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem), sortleft);
745746
}
746747
}
747748

src/avx512-64bit-qsort.hpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -751,11 +751,12 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
751751
}
752752

753753
template <typename vtype, typename type_t>
754-
static void qselect_64bit_(type_t *arr,
755-
int64_t pos,
756-
int64_t left,
757-
int64_t right,
758-
int64_t max_iters)
754+
static void qselsort_64bit_(type_t *arr,
755+
int64_t pos,
756+
int64_t left,
757+
int64_t right,
758+
int64_t max_iters,
759+
bool sortleft)
759760
{
760761
/*
761762
* Resort to std::sort if quicksort isnt making any progress
@@ -777,40 +778,40 @@ static void qselect_64bit_(type_t *arr,
777778
type_t biggest = vtype::type_min();
778779
int64_t pivot_index = partition_avx512_unrolled<vtype, 8>(
779780
arr, left, right + 1, pivot, &smallest, &biggest);
780-
if ((pivot != smallest) && (pos < pivot_index))
781-
qselect_64bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
781+
if ((pivot != smallest) && ((pos < pivot_index) || sortleft))
782+
qselsort_64bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1, sortleft);
782783
else if ((pivot != biggest) && (pos >= pivot_index))
783-
qselect_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
784+
qselsort_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1, sortleft);
784785
}
785786

786787
template <>
787-
void avx512_qselect<int64_t>(int64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
788+
void avx512_qselsort<int64_t>(int64_t *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan)
788789
{
789790
if (arrsize > 1) {
790-
qselect_64bit_<zmm_vector<int64_t>, int64_t>(
791-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
791+
qselsort_64bit_<zmm_vector<int64_t>, int64_t>(
792+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize), sortleft);
792793
}
793794
}
794795

795796
template <>
796-
void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
797+
void avx512_qselsort<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan)
797798
{
798799
if (arrsize > 1) {
799-
qselect_64bit_<zmm_vector<uint64_t>, uint64_t>(
800-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
800+
qselsort_64bit_<zmm_vector<uint64_t>, uint64_t>(
801+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize), sortleft);
801802
}
802803
}
803804

804805
template <>
805-
void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize, bool hasnan)
806+
void avx512_qselsort<double>(double *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan)
806807
{
807808
int64_t indx_last_elem = arrsize - 1;
808809
if (UNLIKELY(hasnan)) {
809810
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
810811
}
811812
if (indx_last_elem >= k) {
812-
qselect_64bit_<zmm_vector<double>, double>(
813-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
813+
qselsort_64bit_<zmm_vector<double>, double>(
814+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem), sortleft);
814815
}
815816
}
816817

src/avx512-common-qsort.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,27 @@ void avx512_qsort(T *arr, int64_t arrsize);
100100
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
101101

102102
template <typename T>
103-
void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
104-
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
103+
void avx512_qselsort(T *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan);
104+
void avx512_qselsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan);
105+
106+
template <typename T>
107+
inline void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
108+
{
109+
avx512_qselsort<T>(arr, k, arrsize, false, hasnan);
110+
}
111+
inline void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
112+
{
113+
avx512_qselsort_fp16(arr, k, arrsize, false, hasnan);
114+
}
105115

106116
template <typename T>
107117
inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
108118
{
109-
avx512_qselect<T>(arr, k - 1, arrsize, hasnan);
110-
avx512_qsort<T>(arr, k - 1);
119+
avx512_qselsort<T>(arr, k - 1, arrsize, true, hasnan);
111120
}
112121
inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
113122
{
114-
avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
115-
avx512_qsort_fp16(arr, k - 1);
123+
avx512_qselsort_fp16(arr, k - 1, arrsize, true, hasnan);
116124
}
117125

118126
// key-value sort routines

src/avx512fp16-16bit-qsort.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ bool is_a_nan<_Float16>(_Float16 elem)
153153
}
154154

155155
template <>
156-
void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
156+
void avx512_qselsort(_Float16 *arr, int64_t k, int64_t arrsize, bool sortleft, bool hasnan)
157157
{
158158
int64_t indx_last_elem = arrsize - 1;
159159
if (UNLIKELY(hasnan)) {
160160
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
161161
}
162162
if (indx_last_elem >= k) {
163-
qselect_16bit_<zmm_vector<_Float16>, _Float16>(
164-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
163+
qselsort_16bit_<zmm_vector<_Float16>, _Float16>(
164+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem), sortleft);
165165
}
166166
}
167167

0 commit comments

Comments
 (0)