@@ -403,56 +403,51 @@ bool is_a_nan<uint16_t>(uint16_t elem)
403
403
return (elem & 0x7c00 ) == 0x7c00 ;
404
404
}
405
405
406
+ /* Specialized template function for 16-bit qsort_ funcs*/
406
407
template <>
407
- void avx512_qselect (int16_t * arr, int64_t k , int64_t arrsize, bool hasnan )
408
+ void qsort_<zmm_vector< int16_t >> (int16_t * arr, int64_t left , int64_t right, int64_t maxiters )
408
409
{
409
- if (arrsize > 1 ) {
410
- qselect_16bit_<zmm_vector<int16_t >, int16_t >(
411
- arr, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
412
- }
410
+ qsort_16bit_<zmm_vector<int16_t >>(arr, left, right, maxiters);
413
411
}
414
412
415
413
template <>
416
- void avx512_qselect (uint16_t * arr, int64_t k , int64_t arrsize, bool hasnan )
414
+ void qsort_<zmm_vector< uint16_t >> (uint16_t * arr, int64_t left , int64_t right, int64_t maxiters )
417
415
{
418
- if (arrsize > 1 ) {
419
- qselect_16bit_<zmm_vector<uint16_t >, uint16_t >(
420
- arr, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
421
- }
416
+ qsort_16bit_<zmm_vector<uint16_t >>(arr, left, right, maxiters);
422
417
}
423
418
424
- void avx512_qselect_fp16 (uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan )
419
+ void avx512_qsort_fp16 (uint16_t *arr, int64_t arrsize)
425
420
{
426
- int64_t indx_last_elem = arrsize - 1 ;
427
- if (UNLIKELY (hasnan)) {
428
- indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
429
- }
430
- if (indx_last_elem >= k) {
431
- qselect_16bit_<zmm_vector<float16>, uint16_t >(
432
- arr, k, 0 , indx_last_elem, 2 * (int64_t )log2 (indx_last_elem));
421
+ if (arrsize > 1 ) {
422
+ int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t >(arr, arrsize);
423
+ qsort_16bit_<zmm_vector<float16>, uint16_t >(
424
+ arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
425
+ replace_inf_with_nan (arr, arrsize, nan_count);
433
426
}
434
427
}
435
428
429
+ /* Specialized template function for 16-bit qselect_ funcs*/
436
430
template <>
437
- void qsort_ <zmm_vector<int16_t >>(int16_t * arr, int64_t left, int64_t right, int64_t maxiters)
431
+ void qselect_ <zmm_vector<int16_t >>(int16_t * arr, int64_t k , int64_t left, int64_t right, int64_t maxiters)
438
432
{
439
- qsort_16bit_ <zmm_vector<int16_t >>(arr, left, right, maxiters);
433
+ qselect_16bit_ <zmm_vector<int16_t >>(arr, k , left, right, maxiters);
440
434
}
441
435
442
436
template <>
443
- void qsort_ <zmm_vector<uint16_t >>(uint16_t * arr, int64_t left, int64_t right, int64_t maxiters)
437
+ void qselect_ <zmm_vector<uint16_t >>(uint16_t * arr, int64_t k , int64_t left, int64_t right, int64_t maxiters)
444
438
{
445
- qsort_16bit_ <zmm_vector<uint16_t >>(arr, left, right, maxiters);
439
+ qselect_16bit_ <zmm_vector<uint16_t >>(arr, k , left, right, maxiters);
446
440
}
447
441
448
- void avx512_qsort_fp16 (uint16_t *arr, int64_t arrsize)
442
+ void avx512_qselect_fp16 (uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan )
449
443
{
450
- if (arrsize > 1 ) {
451
- int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t >(arr, arrsize);
452
- qsort_16bit_<zmm_vector<float16>, uint16_t >(
453
- arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
454
- replace_inf_with_nan (arr, arrsize, nan_count);
444
+ int64_t indx_last_elem = arrsize - 1 ;
445
+ if (UNLIKELY (hasnan)) {
446
+ indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
447
+ }
448
+ if (indx_last_elem >= k) {
449
+ qselect_16bit_<zmm_vector<float16>, uint16_t >(
450
+ arr, k, 0 , indx_last_elem, 2 * (int64_t )log2 (indx_last_elem));
455
451
}
456
452
}
457
-
458
453
#endif // AVX512_QSORT_16BIT
0 commit comments