Skip to content

Commit bc73d86

Browse files
committed
Optimise partial sort to use a single pass
The original partial sort function was a combination of a QuickSelect pass followed by a QuickSort pass. This can lead to unnecessary comparisons and movement of data. This update changes the internal method used by `avx512_partial_qsort` so that it can more efficiently sort or recurse into each partition as required and without needing to take a second sweep over the array. This commit also tweaks the conditional checks in `avx512_qselect` to avoid a redundant comparison.
1 parent 85f4e9c commit bc73d86

File tree

6 files changed

+252
-27
lines changed

6 files changed

+252
-27
lines changed

src/avx512-16bit-common.h

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,53 @@ static void qselect_16bit_(type_t *arr,
316316
type_t biggest = vtype::type_min();
317317
int64_t pivot_index = partition_avx512<vtype>(
318318
arr, left, right + 1, pivot, &smallest, &biggest);
319-
if ((pivot != smallest) && (pos < pivot_index))
320-
qselect_16bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
321-
else if ((pivot != biggest) && (pos >= pivot_index))
322-
qselect_16bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
319+
320+
if (pos < pivot_index) {
321+
if (pivot != smallest)
322+
qselect_16bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
323+
} else {
324+
if (pivot != biggest)
325+
qselect_16bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
326+
}
327+
}
328+
329+
template <typename vtype, typename type_t>
330+
static void qselsort_16bit_(type_t *arr,
331+
int64_t pos,
332+
int64_t left,
333+
int64_t right,
334+
int64_t max_iters)
335+
{
336+
/*
337+
* Resort to std::sort if quicksort isn't making any progress
338+
*/
339+
if (max_iters <= 0) {
340+
std::sort(arr + left, arr + right + 1, comparison_func<vtype>);
341+
return;
342+
}
343+
/*
344+
* Base case: use bitonic networks to sort arrays <= 128
345+
*/
346+
if (right + 1 - left <= 128) {
347+
sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
348+
return;
349+
}
350+
351+
type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
352+
type_t smallest = vtype::type_max();
353+
type_t biggest = vtype::type_min();
354+
int64_t pivot_index = partition_avx512<vtype>(
355+
arr, left, right + 1, pivot, &smallest, &biggest);
356+
357+
if (pos < pivot_index) {
358+
if (pivot != smallest)
359+
qselsort_16bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
360+
} else {
361+
if (pivot != smallest)
362+
qsort_16bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
363+
if (pivot != biggest)
364+
qselsort_16bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
365+
}
323366
}
324367

325368
#endif // AVX512_16BIT_COMMON

src/avx512-16bit-qsort.hpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,14 +433,44 @@ void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
433433
{
434434
int64_t indx_last_elem = arrsize - 1;
435435
if (UNLIKELY(hasnan)) {
436-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
436+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
437437
}
438438
if (indx_last_elem >= k) {
439439
qselect_16bit_<zmm_vector<float16>, uint16_t>(
440440
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
441441
}
442442
}
443443

444+
template <>
445+
void avx512_partial_qsort(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
446+
{
447+
if (arrsize > 1) {
448+
qselsort_16bit_<zmm_vector<int16_t>, int16_t>(
449+
arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
450+
}
451+
}
452+
453+
template <>
454+
void avx512_partial_qsort(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
455+
{
456+
if (arrsize > 1) {
457+
qselsort_16bit_<zmm_vector<uint16_t>, uint16_t>(
458+
arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
459+
}
460+
}
461+
462+
void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
463+
{
464+
if (LIKELY(k > 0)) {
465+
int64_t indx_last_elem = arrsize - 1;
466+
if (UNLIKELY(hasnan)) {
467+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
468+
}
469+
qselsort_16bit_<zmm_vector<float16>, uint16_t>(
470+
arr, k - 1, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
471+
}
472+
}
473+
444474
template <>
445475
void avx512_qsort(int16_t *arr, int64_t arrsize)
446476
{

src/avx512-32bit-qsort.hpp

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -683,10 +683,52 @@ static void qselect_32bit_(type_t *arr,
683683
type_t biggest = vtype::type_min();
684684
int64_t pivot_index = partition_avx512_unrolled<vtype, 2>(
685685
arr, left, right + 1, pivot, &smallest, &biggest);
686-
if ((pivot != smallest) && (pos < pivot_index))
687-
qselect_32bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
688-
else if ((pivot != biggest) && (pos >= pivot_index))
689-
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
686+
if (pos < pivot_index) {
687+
if (pivot != smallest)
688+
qselect_32bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
689+
} else {
690+
if (pivot != biggest)
691+
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
692+
}
693+
}
694+
695+
template <typename vtype, typename type_t>
696+
static void qselsort_32bit_(type_t *arr,
697+
int64_t pos,
698+
int64_t left,
699+
int64_t right,
700+
int64_t max_iters)
701+
{
702+
/*
703+
* Resort to std::sort if quicksort isn't making any progress
704+
*/
705+
if (max_iters <= 0) {
706+
std::sort(arr + left, arr + right + 1);
707+
return;
708+
}
709+
/*
710+
* Base case: use bitonic networks to sort arrays <= 128
711+
*/
712+
if (right + 1 - left <= 128) {
713+
sort_128_32bit<vtype>(arr + left, (int32_t)(right + 1 - left));
714+
return;
715+
}
716+
717+
type_t pivot = get_pivot_32bit<vtype>(arr, left, right);
718+
type_t smallest = vtype::type_max();
719+
type_t biggest = vtype::type_min();
720+
int64_t pivot_index = partition_avx512_unrolled<vtype, 2>(
721+
arr, left, right + 1, pivot, &smallest, &biggest);
722+
723+
if (pos < pivot_index) {
724+
if (pivot != smallest)
725+
qselsort_32bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
726+
} else {
727+
if (pivot != smallest)
728+
qsort_32bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
729+
if (pivot != biggest)
730+
qselsort_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
731+
}
690732
}
691733

692734
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
@@ -737,14 +779,45 @@ void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
737779
{
738780
int64_t indx_last_elem = arrsize - 1;
739781
if (UNLIKELY(hasnan)) {
740-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
782+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
741783
}
742784
if (indx_last_elem >= k) {
743785
qselect_32bit_<zmm_vector<float>, float>(
744786
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
745787
}
746788
}
747789

790+
template <>
791+
void avx512_partial_qsort<int32_t>(int32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
792+
{
793+
if (arrsize > 1) {
794+
qselsort_32bit_<zmm_vector<int32_t>, int32_t>(
795+
arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
796+
}
797+
}
798+
799+
template <>
800+
void avx512_partial_qsort<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
801+
{
802+
if (arrsize > 1) {
803+
qselsort_32bit_<zmm_vector<uint32_t>, uint32_t>(
804+
arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
805+
}
806+
}
807+
808+
template <>
809+
void avx512_partial_qsort<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
810+
{
811+
if (LIKELY(k > 0)) {
812+
int64_t indx_last_elem = arrsize - 1;
813+
if (UNLIKELY(hasnan)) {
814+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
815+
}
816+
qselsort_32bit_<zmm_vector<float>, float>(
817+
arr, k - 1, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
818+
}
819+
}
820+
748821
template <>
749822
void avx512_qsort<int32_t>(int32_t *arr, int64_t arrsize)
750823
{

src/avx512-64bit-qsort.hpp

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,7 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
732732
return;
733733
}
734734
/*
735-
* Base case: use bitonic networks to sort arrays <= 128
735+
* Base case: use bitonic networks to sort arrays <= 256
736736
*/
737737
if (right + 1 - left <= 256) {
738738
sort_256_64bit<vtype>(arr + left, (int32_t)(right + 1 - left));
@@ -777,10 +777,53 @@ static void qselect_64bit_(type_t *arr,
777777
type_t biggest = vtype::type_min();
778778
int64_t pivot_index = partition_avx512_unrolled<vtype, 8>(
779779
arr, left, right + 1, pivot, &smallest, &biggest);
780-
if ((pivot != smallest) && (pos < pivot_index))
781-
qselect_64bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
782-
else if ((pivot != biggest) && (pos >= pivot_index))
783-
qselect_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
780+
781+
if (pos < pivot_index) {
782+
if (pivot != smallest)
783+
qselect_64bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
784+
} else {
785+
if (pivot != biggest)
786+
qselect_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
787+
}
788+
}
789+
790+
template <typename vtype, typename type_t>
791+
static void qselsort_64bit_(type_t *arr,
792+
int64_t pos,
793+
int64_t left,
794+
int64_t right,
795+
int64_t max_iters)
796+
{
797+
/*
798+
* Resort to std::sort if quicksort isn't making any progress
799+
*/
800+
if (max_iters <= 0) {
801+
std::sort(arr + left, arr + right + 1);
802+
return;
803+
}
804+
/*
805+
* Base case: use bitonic networks to sort arrays <= 128
806+
*/
807+
if (right + 1 - left <= 128) {
808+
sort_128_64bit<vtype>(arr + left, (int32_t)(right + 1 - left));
809+
return;
810+
}
811+
812+
type_t pivot = get_pivot_64bit<vtype>(arr, left, right);
813+
type_t smallest = vtype::type_max();
814+
type_t biggest = vtype::type_min();
815+
int64_t pivot_index = partition_avx512_unrolled<vtype, 8>(
816+
arr, left, right + 1, pivot, &smallest, &biggest);
817+
818+
if (pos < pivot_index) {
819+
if (pivot != smallest)
820+
qselsort_64bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
821+
} else {
822+
if (pivot != smallest)
823+
qsort_64bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
824+
if (pivot != biggest)
825+
qselsort_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
826+
}
784827
}
785828

786829
template <>
@@ -806,14 +849,45 @@ void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize, bool hasnan
806849
{
807850
int64_t indx_last_elem = arrsize - 1;
808851
if (UNLIKELY(hasnan)) {
809-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
852+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
810853
}
811854
if (indx_last_elem >= k) {
812855
qselect_64bit_<zmm_vector<double>, double>(
813856
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
814857
}
815858
}
816859

860+
template <>
861+
void avx512_partial_qsort<int64_t>(int64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
862+
{
863+
if (arrsize > 1) {
864+
qselsort_64bit_<zmm_vector<int64_t>, int64_t>(
865+
arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
866+
}
867+
}
868+
869+
template <>
870+
void avx512_partial_qsort<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
871+
{
872+
if (arrsize > 1) {
873+
qselsort_64bit_<zmm_vector<uint64_t>, uint64_t>(
874+
arr, k - 1, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
875+
}
876+
}
877+
878+
template <>
879+
void avx512_partial_qsort<double>(double *arr, int64_t k, int64_t arrsize, bool hasnan)
880+
{
881+
if (LIKELY(k > 0)) {
882+
int64_t indx_last_elem = arrsize - 1;
883+
if (UNLIKELY(hasnan)) {
884+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
885+
}
886+
qselsort_64bit_<zmm_vector<double>, double>(
887+
arr, k - 1, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
888+
}
889+
}
890+
817891
template <>
818892
void avx512_qsort<int64_t>(int64_t *arr, int64_t arrsize)
819893
{

src/avx512-common-qsort.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,8 @@ void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
104104
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
105105

106106
template <typename T>
107-
inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
108-
{
109-
avx512_qselect<T>(arr, k - 1, arrsize, hasnan);
110-
avx512_qsort<T>(arr, k - 1);
111-
}
112-
inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
113-
{
114-
avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
115-
avx512_qsort_fp16(arr, k - 1);
116-
}
107+
void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
108+
void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
117109

118110
// key-value sort routines
119111
template <typename T>

src/avx512fp16-16bit-qsort.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,27 @@ void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
157157
{
158158
int64_t indx_last_elem = arrsize - 1;
159159
if (UNLIKELY(hasnan)) {
160-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
160+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
161161
}
162162
if (indx_last_elem >= k) {
163163
qselect_16bit_<zmm_vector<_Float16>, _Float16>(
164164
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
165165
}
166166
}
167167

168+
template <>
169+
void avx512_partial_qsort(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
170+
{
171+
if (LIKELY(k > 0)) {
172+
int64_t indx_last_elem = arrsize - 1;
173+
if (UNLIKELY(hasnan)) {
174+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
175+
}
176+
qselsort_16bit_<zmm_vector<_Float16>, _Float16>(
177+
arr, k - 1, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
178+
}
179+
}
180+
168181
template <>
169182
void avx512_qsort(_Float16 *arr, int64_t arrsize)
170183
{

0 commit comments

Comments
 (0)