@@ -732,7 +732,7 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
732
732
return ;
733
733
}
734
734
/*
735
- * Base case: use bitonic networks to sort arrays <= 128
735
+ * Base case: use bitonic networks to sort arrays <= 256
736
736
*/
737
737
if (right + 1 - left <= 256 ) {
738
738
sort_256_64bit<vtype>(arr + left, (int32_t )(right + 1 - left));
@@ -777,10 +777,53 @@ static void qselect_64bit_(type_t *arr,
777
777
type_t biggest = vtype::type_min ();
778
778
int64_t pivot_index = partition_avx512_unrolled<vtype, 8 >(
779
779
arr, left, right + 1 , pivot, &smallest, &biggest);
780
- if ((pivot != smallest) && (pos < pivot_index))
781
- qselect_64bit_<vtype>(arr, pos, left, pivot_index - 1 , max_iters - 1 );
782
- else if ((pivot != biggest) && (pos >= pivot_index))
783
- qselect_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1 );
780
+
781
+ if (pos < pivot_index) {
782
+ if (pivot != smallest)
783
+ qselect_64bit_<vtype>(arr, pos, left, pivot_index - 1 , max_iters - 1 );
784
+ } else {
785
+ if (pivot != biggest)
786
+ qselect_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1 );
787
+ }
788
+ }
789
+
790
+ template <typename vtype, typename type_t >
791
+ static void qselsort_64bit_ (type_t *arr,
792
+ int64_t pos,
793
+ int64_t left,
794
+ int64_t right,
795
+ int64_t max_iters)
796
+ {
797
+ /*
798
+ * Resort to std::sort if quicksort isnt making any progress
799
+ */
800
+ if (max_iters <= 0 ) {
801
+ std::sort (arr + left, arr + right + 1 );
802
+ return ;
803
+ }
804
+ /*
805
+ * Base case: use bitonic networks to sort arrays <= 128
806
+ */
807
+ if (right + 1 - left <= 128 ) {
808
+ sort_128_64bit<vtype>(arr + left, (int32_t )(right + 1 - left));
809
+ return ;
810
+ }
811
+
812
+ type_t pivot = get_pivot_64bit<vtype>(arr, left, right);
813
+ type_t smallest = vtype::type_max ();
814
+ type_t biggest = vtype::type_min ();
815
+ int64_t pivot_index = partition_avx512_unrolled<vtype, 8 >(
816
+ arr, left, right + 1 , pivot, &smallest, &biggest);
817
+
818
+ if (pos < pivot_index) {
819
+ if (pivot != smallest)
820
+ qselsort_64bit_<vtype>(arr, pos, left, pivot_index - 1 , max_iters - 1 );
821
+ } else {
822
+ if (pivot != smallest)
823
+ qsort_64bit_<vtype>(arr, left, pivot_index - 1 , max_iters - 1 );
824
+ if (pivot != biggest)
825
+ qselsort_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1 );
826
+ }
784
827
}
785
828
786
829
template <>
@@ -806,14 +849,45 @@ void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize, bool hasnan
806
849
{
807
850
int64_t indx_last_elem = arrsize - 1 ;
808
851
if (UNLIKELY (hasnan)) {
809
- indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
852
+ indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
810
853
}
811
854
if (indx_last_elem >= k) {
812
855
qselect_64bit_<zmm_vector<double >, double >(
813
856
arr, k, 0 , indx_last_elem, 2 * (int64_t )log2 (indx_last_elem));
814
857
}
815
858
}
816
859
860
+ template <>
861
+ void avx512_partial_qsort<int64_t >(int64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
862
+ {
863
+ if (arrsize > 1 ) {
864
+ qselsort_64bit_<zmm_vector<int64_t >, int64_t >(
865
+ arr, k - 1 , 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
866
+ }
867
+ }
868
+
869
+ template <>
870
+ void avx512_partial_qsort<uint64_t >(uint64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
871
+ {
872
+ if (arrsize > 1 ) {
873
+ qselsort_64bit_<zmm_vector<uint64_t >, uint64_t >(
874
+ arr, k - 1 , 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
875
+ }
876
+ }
877
+
878
+ template <>
879
+ void avx512_partial_qsort<double >(double *arr, int64_t k, int64_t arrsize, bool hasnan)
880
+ {
881
+ if (LIKELY (k > 0 )) {
882
+ int64_t indx_last_elem = arrsize - 1 ;
883
+ if (UNLIKELY (hasnan)) {
884
+ indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
885
+ }
886
+ qselsort_64bit_<zmm_vector<double >, double >(
887
+ arr, k - 1 , 0 , indx_last_elem, 2 * (int64_t )log2 (indx_last_elem));
888
+ }
889
+ }
890
+
817
891
template <>
818
892
void avx512_qsort<int64_t >(int64_t *arr, int64_t arrsize)
819
893
{
0 commit comments