Skip to content

Commit 42c672f

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #59 from r-devulap/maint-templates
MAINT: Remove template specializations for quicksort methods
2 parents dfbcb09 + 4efed2a commit 42c672f

11 files changed

+370
-482
lines changed

meson.build

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
project('x86-simd-sort', 'cpp',
2-
version : '1.0.0',
3-
license : 'BSD 3-clause')
2+
version : '2.0.0',
3+
license : 'BSD 3-clause',
4+
default_options : ['cpp_std=c++17'])
45
cpp = meson.get_compiler('cpp')
56
src = include_directories('src')
67
bench = include_directories('benchmarks')

src/avx512-16bit-qsort.hpp

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,8 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
377377
//return npy_half_to_float(a) < npy_half_to_float(b);
378378
}
379379

380-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
380+
template <>
381+
int64_t replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr,
381382
int64_t arrsize)
382383
{
383384
int64_t nan_count = 0;
@@ -396,77 +397,66 @@ X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
396397
return nan_count;
397398
}
398399

399-
X86_SIMD_SORT_INLINE void
400-
replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
401-
{
402-
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
403-
arr[ii] = 0xFFFF;
404-
nan_count -= 1;
405-
}
406-
}
407-
408400
template <>
409401
bool is_a_nan<uint16_t>(uint16_t elem)
410402
{
411403
return (elem & 0x7c00) == 0x7c00;
412404
}
413405

406+
/* Specialized template function for 16-bit qsort_ funcs*/
414407
template <>
415-
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
408+
void qsort_<zmm_vector<int16_t>>(int16_t *arr,
409+
int64_t left,
410+
int64_t right,
411+
int64_t maxiters)
416412
{
417-
if (arrsize > 1) {
418-
qselect_16bit_<zmm_vector<int16_t>, int16_t>(
419-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
420-
}
413+
qsort_16bit_<zmm_vector<int16_t>>(arr, left, right, maxiters);
421414
}
422415

423416
template <>
424-
void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
417+
void qsort_<zmm_vector<uint16_t>>(uint16_t *arr,
418+
int64_t left,
419+
int64_t right,
420+
int64_t maxiters)
425421
{
426-
if (arrsize > 1) {
427-
qselect_16bit_<zmm_vector<uint16_t>, uint16_t>(
428-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
429-
}
422+
qsort_16bit_<zmm_vector<uint16_t>>(arr, left, right, maxiters);
430423
}
431424

432-
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
425+
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
433426
{
434-
int64_t indx_last_elem = arrsize - 1;
435-
if (UNLIKELY(hasnan)) {
436-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
437-
}
438-
if (indx_last_elem >= k) {
439-
qselect_16bit_<zmm_vector<float16>, uint16_t>(
440-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
427+
if (arrsize > 1) {
428+
int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(
429+
arr, arrsize);
430+
qsort_16bit_<zmm_vector<float16>, uint16_t>(
431+
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
432+
replace_inf_with_nan(arr, arrsize, nan_count);
441433
}
442434
}
443435

436+
/* Specialized template function for 16-bit qselect_ funcs*/
444437
template <>
445-
void avx512_qsort(int16_t *arr, int64_t arrsize)
438+
void qselect_<zmm_vector<int16_t>>(
439+
int16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
446440
{
447-
if (arrsize > 1) {
448-
qsort_16bit_<zmm_vector<int16_t>, int16_t>(
449-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
450-
}
441+
qselect_16bit_<zmm_vector<int16_t>>(arr, k, left, right, maxiters);
451442
}
452443

453444
template <>
454-
void avx512_qsort(uint16_t *arr, int64_t arrsize)
445+
void qselect_<zmm_vector<uint16_t>>(
446+
uint16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
455447
{
456-
if (arrsize > 1) {
457-
qsort_16bit_<zmm_vector<uint16_t>, uint16_t>(
458-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
459-
}
448+
qselect_16bit_<zmm_vector<uint16_t>>(arr, k, left, right, maxiters);
460449
}
461450

462-
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
451+
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
463452
{
464-
if (arrsize > 1) {
465-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
466-
qsort_16bit_<zmm_vector<float16>, uint16_t>(
467-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
468-
replace_inf_with_nan(arr, arrsize, nan_count);
453+
int64_t indx_last_elem = arrsize - 1;
454+
if (UNLIKELY(hasnan)) {
455+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
456+
}
457+
if (indx_last_elem >= k) {
458+
qselect_16bit_<zmm_vector<float16>, uint16_t>(
459+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
469460
}
470461
}
471-
472462
#endif // AVX512_QSORT_16BIT

src/avx512-32bit-qsort.hpp

Lines changed: 39 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,15 @@ struct zmm_vector<float> {
256256
{
257257
return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
258258
}
259+
static opmask_t get_partial_loadmask(int size)
260+
{
261+
return (0x0001 << size) - 0x0001;
262+
}
263+
template <int type>
264+
static opmask_t fpclass(zmm_t x)
265+
{
266+
return _mm512_fpclass_ps_mask(x, type);
267+
}
259268
template <int scale>
260269
static ymm_t i64gather(__m512i index, void const *base)
261270
{
@@ -279,6 +288,10 @@ struct zmm_vector<float> {
279288
{
280289
return _mm512_mask_compressstoreu_ps(mem, mask, x);
281290
}
291+
static zmm_t maskz_loadu(opmask_t mask, void const *mem)
292+
{
293+
return _mm512_maskz_loadu_ps(mask, mem);
294+
}
282295
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
283296
{
284297
return _mm512_mask_loadu_ps(x, mask, mem);
@@ -689,95 +702,53 @@ static void qselect_32bit_(type_t *arr,
689702
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
690703
}
691704

692-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
693-
{
694-
int64_t nan_count = 0;
695-
__mmask16 loadmask = 0xFFFF;
696-
while (arrsize > 0) {
697-
if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
698-
__m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr);
699-
__mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
700-
nan_count += _mm_popcnt_u32((int32_t)nanmask);
701-
_mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT);
702-
arr += 16;
703-
arrsize -= 16;
704-
}
705-
return nan_count;
706-
}
707-
708-
X86_SIMD_SORT_INLINE void
709-
replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
710-
{
711-
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
712-
arr[ii] = std::nanf("1");
713-
nan_count -= 1;
714-
}
715-
}
716-
705+
/* Specialized template function for 32-bit qselect_ funcs*/
717706
template <>
718-
void avx512_qselect<int32_t>(int32_t *arr,
719-
int64_t k,
720-
int64_t arrsize,
721-
bool hasnan)
707+
void qselect_<zmm_vector<int32_t>>(
708+
int32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
722709
{
723-
if (arrsize > 1) {
724-
qselect_32bit_<zmm_vector<int32_t>, int32_t>(
725-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
726-
}
710+
qselect_32bit_<zmm_vector<int32_t>>(arr, k, left, right, maxiters);
727711
}
728712

729713
template <>
730-
void avx512_qselect<uint32_t>(uint32_t *arr,
731-
int64_t k,
732-
int64_t arrsize,
733-
bool hasnan)
714+
void qselect_<zmm_vector<uint32_t>>(
715+
uint32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
734716
{
735-
if (arrsize > 1) {
736-
qselect_32bit_<zmm_vector<uint32_t>, uint32_t>(
737-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
738-
}
717+
qselect_32bit_<zmm_vector<uint32_t>>(arr, k, left, right, maxiters);
739718
}
740719

741720
template <>
742-
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
721+
void qselect_<zmm_vector<float>>(
722+
float *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
743723
{
744-
int64_t indx_last_elem = arrsize - 1;
745-
if (UNLIKELY(hasnan)) {
746-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
747-
}
748-
if (indx_last_elem >= k) {
749-
qselect_32bit_<zmm_vector<float>, float>(
750-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
751-
}
724+
qselect_32bit_<zmm_vector<float>>(arr, k, left, right, maxiters);
752725
}
753726

727+
/* Specialized template function for 32-bit qsort_ funcs*/
754728
template <>
755-
void avx512_qsort<int32_t>(int32_t *arr, int64_t arrsize)
729+
void qsort_<zmm_vector<int32_t>>(int32_t *arr,
730+
int64_t left,
731+
int64_t right,
732+
int64_t maxiters)
756733
{
757-
if (arrsize > 1) {
758-
qsort_32bit_<zmm_vector<int32_t>, int32_t>(
759-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
760-
}
734+
qsort_32bit_<zmm_vector<int32_t>>(arr, left, right, maxiters);
761735
}
762736

763737
template <>
764-
void avx512_qsort<uint32_t>(uint32_t *arr, int64_t arrsize)
738+
void qsort_<zmm_vector<uint32_t>>(uint32_t *arr,
739+
int64_t left,
740+
int64_t right,
741+
int64_t maxiters)
765742
{
766-
if (arrsize > 1) {
767-
qsort_32bit_<zmm_vector<uint32_t>, uint32_t>(
768-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
769-
}
743+
qsort_32bit_<zmm_vector<uint32_t>>(arr, left, right, maxiters);
770744
}
771745

772746
template <>
773-
void avx512_qsort<float>(float *arr, int64_t arrsize)
747+
void qsort_<zmm_vector<float>>(float *arr,
748+
int64_t left,
749+
int64_t right,
750+
int64_t maxiters)
774751
{
775-
if (arrsize > 1) {
776-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
777-
qsort_32bit_<zmm_vector<float>, float>(
778-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
779-
replace_inf_with_nan(arr, arrsize, nan_count);
780-
}
752+
qsort_32bit_<zmm_vector<float>>(arr, left, right, maxiters);
781753
}
782-
783754
#endif //AVX512_QSORT_32BIT

0 commit comments

Comments
 (0)