@@ -256,6 +256,15 @@ struct zmm_vector<float> {
256
256
{
257
257
return _mm512_cmp_ps_mask (x, y, _CMP_GE_OQ);
258
258
}
259
+ static opmask_t get_partial_loadmask (int size)
260
+ {
261
+ return (0x0001 << size) - 0x0001 ;
262
+ }
263
    // Classify every float lane of x against the FP-class category bits
    // selected by `type` (an _MM_FPCLASS_* combination, e.g. quiet|signaling
    // NaN); returns a per-lane mask with a 1 for each lane that matches.
    template <int type>
    static opmask_t fpclass(zmm_t x)
    {
        return _mm512_fpclass_ps_mask(x, type);
    }
259
268
template <int scale>
260
269
static ymm_t i64gather (__m512i index, void const *base)
261
270
{
@@ -279,6 +288,10 @@ struct zmm_vector<float> {
279
288
{
280
289
return _mm512_mask_compressstoreu_ps (mem, mask, x);
281
290
}
291
    // Masked unaligned load: lanes whose mask bit is set are read from
    // `mem`; all other lanes are zeroed (the "maskz" intrinsic variant).
    static zmm_t maskz_loadu(opmask_t mask, void const *mem)
    {
        return _mm512_maskz_loadu_ps(mask, mem);
    }
282
295
static zmm_t mask_loadu (zmm_t x, opmask_t mask, void const *mem)
283
296
{
284
297
return _mm512_mask_loadu_ps (x, mask, mem);
@@ -689,95 +702,53 @@ static void qselect_32bit_(type_t *arr,
689
702
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1 );
690
703
}
691
704
692
// Scan `arr` 16 floats at a time, overwrite every NaN with the
// ZMM_MAX_FLOAT sentinel (defined elsewhere; presumably +inf/FLT_MAX —
// confirm against its definition) and return the number of NaNs replaced
// so the caller can restore them after sorting.
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
{
    int64_t nan_count = 0;
    __mmask16 loadmask = 0xFFFF;
    while (arrsize > 0) {
        // Final partial chunk: only load the remaining `arrsize` lanes.
        if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
        __m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr);
        // NaN is the only value for which x != x (unordered compare).
        __mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
        nan_count += _mm_popcnt_u32((int32_t)nanmask);
        // Store the sentinel only into the NaN lanes.
        _mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT);
        arr += 16;
        arrsize -= 16;
    }
    return nan_count;
}
707
-
708
// Undo replace_nan_with_inf(): after sorting, the sentinel values sit at
// the tail of the array, so write a NaN back into the last `nan_count`
// slots of arr[0..arrsize).
X86_SIMD_SORT_INLINE void
replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
{
    int64_t idx = arrsize;
    while (nan_count > 0) {
        arr[--idx] = std::nanf("1");
        --nan_count;
    }
}
716
-
705
/* Specialized template function for 32-bit qselect_ funcs */
// Dispatch qselect_ for int32_t to the shared 32-bit AVX-512 kernel.
template <>
void qselect_<zmm_vector<int32_t>>(
        int32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
{
    qselect_32bit_<zmm_vector<int32_t>>(arr, k, left, right, maxiters);
}
728
712
729
713
// Dispatch qselect_ for uint32_t to the shared 32-bit AVX-512 kernel.
template <>
void qselect_<zmm_vector<uint32_t>>(
        uint32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
{
    qselect_32bit_<zmm_vector<uint32_t>>(arr, k, left, right, maxiters);
}
740
719
741
720
// Dispatch qselect_ for float to the shared 32-bit AVX-512 kernel.
// NOTE(review): the old avx512_qselect<float> handled NaNs here
// (move_nans_to_end_of_array); that responsibility now presumably lives
// in the caller — confirm at the new dispatch site.
template <>
void qselect_<zmm_vector<float>>(
        float *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
{
    qselect_32bit_<zmm_vector<float>>(arr, k, left, right, maxiters);
}
753
726
727
/* Specialized template function for 32-bit qsort_ funcs */
// Dispatch qsort_ for int32_t to the shared 32-bit AVX-512 kernel.
template <>
void qsort_<zmm_vector<int32_t>>(int32_t *arr,
                                 int64_t left,
                                 int64_t right,
                                 int64_t maxiters)
{
    qsort_32bit_<zmm_vector<int32_t>>(arr, left, right, maxiters);
}
762
736
763
737
// Dispatch qsort_ for uint32_t to the shared 32-bit AVX-512 kernel.
template <>
void qsort_<zmm_vector<uint32_t>>(uint32_t *arr,
                                  int64_t left,
                                  int64_t right,
                                  int64_t maxiters)
{
    qsort_32bit_<zmm_vector<uint32_t>>(arr, left, right, maxiters);
}
771
745
772
746
// Dispatch qsort_ for float to the shared 32-bit AVX-512 kernel.
// NOTE(review): the old avx512_qsort<float> wrapped the sort with
// replace_nan_with_inf()/replace_inf_with_nan(); that NaN handling is
// presumably now done by the caller — confirm at the new dispatch site.
template <>
void qsort_<zmm_vector<float>>(float *arr,
                               int64_t left,
                               int64_t right,
                               int64_t maxiters)
{
    qsort_32bit_<zmm_vector<float>>(arr, left, right, maxiters);
}
782
-
783
754
#endif // AVX512_QSORT_32BIT
0 commit comments