Skip to content

Commit 472c7d0

Browse files
author
Raghuveer Devulapalli
committed
minimize template specialization for avx512_qselect
1 parent 3ddc914 commit 472c7d0

File tree

5 files changed

+69
-98
lines changed

5 files changed

+69
-98
lines changed

src/avx512-16bit-qsort.hpp

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -403,56 +403,51 @@ bool is_a_nan<uint16_t>(uint16_t elem)
403403
return (elem & 0x7c00) == 0x7c00;
404404
}
405405

406+
/* Specialized template function for 16-bit qsort_ funcs*/
406407
template <>
407-
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
408+
void qsort_<zmm_vector<int16_t>>(int16_t* arr, int64_t left, int64_t right, int64_t maxiters)
408409
{
409-
if (arrsize > 1) {
410-
qselect_16bit_<zmm_vector<int16_t>, int16_t>(
411-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
412-
}
410+
qsort_16bit_<zmm_vector<int16_t>>(arr, left, right, maxiters);
413411
}
414412

415413
template <>
416-
void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
414+
void qsort_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t left, int64_t right, int64_t maxiters)
417415
{
418-
if (arrsize > 1) {
419-
qselect_16bit_<zmm_vector<uint16_t>, uint16_t>(
420-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
421-
}
416+
qsort_16bit_<zmm_vector<uint16_t>>(arr, left, right, maxiters);
422417
}
423418

424-
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
419+
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
425420
{
426-
int64_t indx_last_elem = arrsize - 1;
427-
if (UNLIKELY(hasnan)) {
428-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
429-
}
430-
if (indx_last_elem >= k) {
431-
qselect_16bit_<zmm_vector<float16>, uint16_t>(
432-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
421+
if (arrsize > 1) {
422+
int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(arr, arrsize);
423+
qsort_16bit_<zmm_vector<float16>, uint16_t>(
424+
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
425+
replace_inf_with_nan(arr, arrsize, nan_count);
433426
}
434427
}
435428

429+
/* Specialized template function for 16-bit qselect_ funcs*/
436430
template <>
437-
void qsort_<zmm_vector<int16_t>>(int16_t* arr, int64_t left, int64_t right, int64_t maxiters)
431+
void qselect_<zmm_vector<int16_t>>(int16_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
438432
{
439-
qsort_16bit_<zmm_vector<int16_t>>(arr, left, right, maxiters);
433+
qselect_16bit_<zmm_vector<int16_t>>(arr, k, left, right, maxiters);
440434
}
441435

442436
template <>
443-
void qsort_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t left, int64_t right, int64_t maxiters)
437+
void qselect_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
444438
{
445-
qsort_16bit_<zmm_vector<uint16_t>>(arr, left, right, maxiters);
439+
qselect_16bit_<zmm_vector<uint16_t>>(arr, k, left, right, maxiters);
446440
}
447441

448-
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
442+
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
449443
{
450-
if (arrsize > 1) {
451-
int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(arr, arrsize);
452-
qsort_16bit_<zmm_vector<float16>, uint16_t>(
453-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
454-
replace_inf_with_nan(arr, arrsize, nan_count);
444+
int64_t indx_last_elem = arrsize - 1;
445+
if (UNLIKELY(hasnan)) {
446+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
447+
}
448+
if (indx_last_elem >= k) {
449+
qselect_16bit_<zmm_vector<float16>, uint16_t>(
450+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
455451
}
456452
}
457-
458453
#endif // AVX512_QSORT_16BIT

src/avx512-32bit-qsort.hpp

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -702,43 +702,26 @@ static void qselect_32bit_(type_t *arr,
702702
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
703703
}
704704

705+
/* Specialized template function for 32-bit qselect_ funcs*/
705706
template <>
706-
void avx512_qselect<int32_t>(int32_t *arr,
707-
int64_t k,
708-
int64_t arrsize,
709-
bool hasnan)
707+
void qselect_<zmm_vector<int32_t>>(int32_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
710708
{
711-
if (arrsize > 1) {
712-
qselect_32bit_<zmm_vector<int32_t>, int32_t>(
713-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
714-
}
709+
qselect_32bit_<zmm_vector<int32_t>>(arr, k, left, right, maxiters);
715710
}
716711

717712
template <>
718-
void avx512_qselect<uint32_t>(uint32_t *arr,
719-
int64_t k,
720-
int64_t arrsize,
721-
bool hasnan)
713+
void qselect_<zmm_vector<uint32_t>>(uint32_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
722714
{
723-
if (arrsize > 1) {
724-
qselect_32bit_<zmm_vector<uint32_t>, uint32_t>(
725-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
726-
}
715+
qselect_32bit_<zmm_vector<uint32_t>>(arr, k, left, right, maxiters);
727716
}
728717

729718
template <>
730-
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
719+
void qselect_<zmm_vector<float>>(float* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
731720
{
732-
int64_t indx_last_elem = arrsize - 1;
733-
if (UNLIKELY(hasnan)) {
734-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
735-
}
736-
if (indx_last_elem >= k) {
737-
qselect_32bit_<zmm_vector<float>, float>(
738-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
739-
}
721+
qselect_32bit_<zmm_vector<float>>(arr, k, left, right, maxiters);
740722
}
741723

724+
/* Specialized template function for 32-bit qsort_ funcs*/
742725
template <>
743726
void qsort_<zmm_vector<int32_t>>(int32_t* arr, int64_t left, int64_t right, int64_t maxiters)
744727
{

src/avx512-64bit-qsort.hpp

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -783,46 +783,26 @@ static void qselect_64bit_(type_t *arr,
783783
qselect_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
784784
}
785785

786+
/* Specialized template function for 64-bit qselect_ funcs*/
786787
template <>
787-
void avx512_qselect<int64_t>(int64_t *arr,
788-
int64_t k,
789-
int64_t arrsize,
790-
bool hasnan)
788+
void qselect_<zmm_vector<int64_t>>(int64_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
791789
{
792-
if (arrsize > 1) {
793-
qselect_64bit_<zmm_vector<int64_t>, int64_t>(
794-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
795-
}
790+
qselect_64bit_<zmm_vector<int64_t>>(arr, k, left, right, maxiters);
796791
}
797792

798793
template <>
799-
void avx512_qselect<uint64_t>(uint64_t *arr,
800-
int64_t k,
801-
int64_t arrsize,
802-
bool hasnan)
794+
void qselect_<zmm_vector<uint64_t>>(uint64_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
803795
{
804-
if (arrsize > 1) {
805-
qselect_64bit_<zmm_vector<uint64_t>, uint64_t>(
806-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
807-
}
796+
qselect_64bit_<zmm_vector<uint64_t>>(arr, k, left, right, maxiters);
808797
}
809798

810799
template <>
811-
void avx512_qselect<double>(double *arr,
812-
int64_t k,
813-
int64_t arrsize,
814-
bool hasnan)
800+
void qselect_<zmm_vector<double>>(double* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
815801
{
816-
int64_t indx_last_elem = arrsize - 1;
817-
if (UNLIKELY(hasnan)) {
818-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
819-
}
820-
if (indx_last_elem >= k) {
821-
qselect_64bit_<zmm_vector<double>, double>(
822-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
823-
}
802+
qselect_64bit_<zmm_vector<double>>(arr, k, left, right, maxiters);
824803
}
825804

805+
/* Specialized template function for 64-bit qsort_ funcs*/
826806
template <>
827807
void qsort_<zmm_vector<int64_t>>(int64_t* arr, int64_t left, int64_t right, int64_t maxiters)
828808
{

src/avx512-common-qsort.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,11 +673,15 @@ static inline int64_t partition_avx512(type_t1 *keys,
673673
template <typename vtype, typename type_t>
674674
void qsort_(type_t* arr, int64_t left, int64_t right, int64_t maxiters);
675675

676+
template <typename vtype, typename type_t>
677+
void qselect_(type_t* arr, int64_t pos, int64_t left, int64_t right, int64_t maxiters);
678+
676679
// Regular quicksort routines:
677680
template <typename T>
678681
void avx512_qsort(T *arr, int64_t arrsize)
679682
{
680683
if (arrsize > 1) {
684+
/* std::is_floating_point_v<_Float16> == False, unless c++-23*/
681685
if constexpr (std::is_floating_point_v<T>) {
682686
int64_t nan_count = replace_nan_with_inf<zmm_vector<T>>(arr, arrsize);
683687
qsort_<zmm_vector<T>, T>(
@@ -694,7 +698,21 @@ void avx512_qsort(T *arr, int64_t arrsize)
694698
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
695699

696700
template <typename T>
697-
void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
701+
void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
702+
{
703+
int64_t indx_last_elem = arrsize - 1;
704+
/* std::is_floating_point_v<_Float16> == False, unless c++-23*/
705+
if constexpr (std::is_floating_point_v<T>) {
706+
if (UNLIKELY(hasnan)) {
707+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
708+
}
709+
}
710+
if (indx_last_elem >= k) {
711+
qselect_<zmm_vector<T>, T>(
712+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
713+
}
714+
}
715+
698716
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
699717

700718
template <typename T>

src/avx512fp16-16bit-qsort.hpp

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -135,31 +135,26 @@ bool is_a_nan<_Float16>(_Float16 elem)
135135
return (temp.i_ & 0x7c00) == 0x7c00;
136136
}
137137

138-
template <>
139-
void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
138+
template<>
139+
void replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
140140
{
141-
int64_t indx_last_elem = arrsize - 1;
142-
if (UNLIKELY(hasnan)) {
143-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
144-
}
145-
if (indx_last_elem >= k) {
146-
qselect_16bit_<zmm_vector<_Float16>, _Float16>(
147-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
148-
}
141+
memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
149142
}
150143

151144
template <>
152-
void qsort_<zmm_vector<_Float16>>(_Float16* arr, int64_t left, int64_t right, int64_t maxiters)
145+
void qselect_<zmm_vector<_Float16>>(_Float16* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
153146
{
154-
qsort_16bit_<zmm_vector<_Float16>>(arr, left, right, maxiters);
147+
qselect_16bit_<zmm_vector<_Float16>>(arr, k, left, right, maxiters);
155148
}
156149

157-
template<>
158-
void replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
150+
151+
template <>
152+
void qsort_<zmm_vector<_Float16>>(_Float16* arr, int64_t left, int64_t right, int64_t maxiters)
159153
{
160-
memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
154+
qsort_16bit_<zmm_vector<_Float16>>(arr, left, right, maxiters);
161155
}
162156

157+
/* Specialized template function for _Float16 qsort_*/
163158
template<>
164159
void avx512_qsort(_Float16 *arr, int64_t arrsize)
165160
{

0 commit comments

Comments
 (0)