8
8
#define AVX512_ARGSORT_64BIT
9
9
10
10
#include " avx512-64bit-common.h"
11
- #include " avx512-common-argsort.h"
12
11
#include " avx512-64bit-keyvalue-networks.hpp"
12
+ #include " avx512-common-argsort.h"
13
13
14
14
/*
 * Scalar argselect fallback for data that may contain NaN.
 *
 * Rearranges the indices stored in arg[left, right) with std::nth_element so
 * that arg[k] holds the index it would hold if that range were fully
 * arg-sorted by the values in arr. The values in arr are only read, never
 * written.
 *
 * Ordering used by the comparator:
 *   - two non-NaN values compare with operator<
 *   - a NaN never compares less than anything (NaNs are mutually equivalent)
 *   - any non-NaN value compares less than a NaN
 * so indices of NaN elements end up after the indices of all ordinary values.
 *
 * arr   : values being ranked (indexed through the entries of arg)
 * arg   : index array that is partially reordered in place
 * k     : position of the selected element, as an offset from arg itself
 *         (not relative to left)
 * left  : first position of the range to consider (inclusive)
 * right : end of the range to consider (exclusive)
 */
template <typename T>
void std_argselect_withnan(
        T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right)
{
    std::nth_element(arg + left,
                     arg + k,
                     arg + right,
                     [arr](int64_t a, int64_t b) -> bool {
                         if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {
                             return arr[a] < arr[b];
                         }
                         else if (std::isnan(arr[a])) {
                             // NaN is never "less than" anything.
                             return false;
                         }
                         else {
                             // arr[a] is a number, arr[b] is NaN.
                             return true;
                         }
                     });
}
26
33
27
-
28
34
/*
 * argsort using std::sort, for data that may contain NaN.
 *
 * Sorts the indices stored in arg[left, right) in place so that the values
 * arr[arg[i]] are in ascending order. The values in arr are only read, never
 * written.
 *
 * Ordering used by the comparator:
 *   - two non-NaN values compare with operator<
 *   - a NaN never compares less than anything (NaNs are mutually equivalent)
 *   - any non-NaN value compares less than a NaN
 * so indices of NaN elements are placed after the indices of all ordinary
 * values. std::sort is not stable, so the relative order of indices whose
 * values are equivalent (including multiple NaNs) is unspecified.
 *
 * arr   : values being ranked (indexed through the entries of arg)
 * arg   : index array that is reordered in place
 * left  : first position of the range to sort (inclusive)
 * right : end of the range to sort (exclusive)
 */
template <typename T>
void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right)
{
    std::sort(arg + left,
              arg + right,
              // Parameters are named a/b so they do not shadow the
              // enclosing function's left/right.
              [arr](int64_t a, int64_t b) -> bool {
                  if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {
                      return arr[a] < arr[b];
                  }
                  else if (std::isnan(arr[a])) {
                      // NaN is never "less than" anything.
                      return false;
                  }
                  else {
                      // arr[a] is a number, arr[b] is NaN.
                      return true;
                  }
              });
}
40
52
@@ -325,13 +337,15 @@ static void argselect_64bit_(type_t *arr,
325
337
int64_t pivot_index = partition_avx512_unrolled<vtype, 4 >(
326
338
arr, arg, left, right + 1 , pivot, &smallest, &biggest);
327
339
if ((pivot != smallest) && (pos < pivot_index))
328
- argselect_64bit_<vtype>(arr, arg, pos, left, pivot_index - 1 , max_iters - 1 );
340
+ argselect_64bit_<vtype>(
341
+ arr, arg, pos, left, pivot_index - 1 , max_iters - 1 );
329
342
else if ((pivot != biggest) && (pos >= pivot_index))
330
- argselect_64bit_<vtype>(arr, arg, pos, pivot_index, right, max_iters - 1 );
343
+ argselect_64bit_<vtype>(
344
+ arr, arg, pos, pivot_index, right, max_iters - 1 );
331
345
}
332
346
333
347
template <typename vtype, typename type_t >
334
- bool has_nan (type_t * arr, int64_t arrsize)
348
+ bool has_nan (type_t * arr, int64_t arrsize)
335
349
{
336
350
using opmask_t = typename vtype::opmask_t ;
337
351
using zmm_t = typename vtype::zmm_t ;
@@ -346,7 +360,7 @@ bool has_nan(type_t* arr, int64_t arrsize)
346
360
else {
347
361
in = vtype::loadu (arr);
348
362
}
349
- opmask_t nanmask = vtype::template fpclass<0x01 | 0x80 >(in);
363
+ opmask_t nanmask = vtype::template fpclass<0x01 | 0x80 >(in);
350
364
arr += vtype::numlanes;
351
365
arrsize -= vtype::numlanes;
352
366
if (nanmask != 0x00 ) {
@@ -357,10 +371,9 @@ bool has_nan(type_t* arr, int64_t arrsize)
357
371
return found_nan;
358
372
}
359
373
360
-
361
374
/* argsort methods for 32-bit and 64-bit dtypes */
362
375
template <typename T>
363
- void avx512_argsort (T* arr, int64_t *arg, int64_t arrsize)
376
+ void avx512_argsort (T * arr, int64_t *arg, int64_t arrsize)
364
377
{
365
378
if (arrsize > 1 ) {
366
379
argsort_64bit_<zmm_vector<T>>(
@@ -369,7 +382,7 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
369
382
}
370
383
371
384
template <>
372
- void avx512_argsort (double * arr, int64_t *arg, int64_t arrsize)
385
+ void avx512_argsort (double * arr, int64_t *arg, int64_t arrsize)
373
386
{
374
387
if (arrsize > 1 ) {
375
388
if (has_nan<zmm_vector<double >>(arr, arrsize)) {
@@ -382,9 +395,8 @@ void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize)
382
395
}
383
396
}
384
397
385
-
386
398
template <>
387
- void avx512_argsort (int32_t * arr, int64_t *arg, int64_t arrsize)
399
+ void avx512_argsort (int32_t * arr, int64_t *arg, int64_t arrsize)
388
400
{
389
401
if (arrsize > 1 ) {
390
402
argsort_64bit_<ymm_vector<int32_t >>(
@@ -393,7 +405,7 @@ void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize)
393
405
}
394
406
395
407
template <>
396
- void avx512_argsort (uint32_t * arr, int64_t *arg, int64_t arrsize)
408
+ void avx512_argsort (uint32_t * arr, int64_t *arg, int64_t arrsize)
397
409
{
398
410
if (arrsize > 1 ) {
399
411
argsort_64bit_<ymm_vector<uint32_t >>(
@@ -402,7 +414,7 @@ void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize)
402
414
}
403
415
404
416
template <>
405
- void avx512_argsort (float * arr, int64_t *arg, int64_t arrsize)
417
+ void avx512_argsort (float * arr, int64_t *arg, int64_t arrsize)
406
418
{
407
419
if (arrsize > 1 ) {
408
420
if (has_nan<ymm_vector<float >>(arr, arrsize)) {
@@ -416,7 +428,7 @@ void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize)
416
428
}
417
429
418
430
template <typename T>
419
- std::vector<int64_t > avx512_argsort (T* arr, int64_t arrsize)
431
+ std::vector<int64_t > avx512_argsort (T * arr, int64_t arrsize)
420
432
{
421
433
std::vector<int64_t > indices (arrsize);
422
434
std::iota (indices.begin (), indices.end (), 0 );
@@ -426,7 +438,7 @@ std::vector<int64_t> avx512_argsort(T* arr, int64_t arrsize)
426
438
427
439
/* argselect methods for 32-bit and 64-bit dtypes */
428
440
template <typename T>
429
- void avx512_argselect (T* arr, int64_t *arg, int64_t k, int64_t arrsize)
441
+ void avx512_argselect (T * arr, int64_t *arg, int64_t k, int64_t arrsize)
430
442
{
431
443
if (arrsize > 1 ) {
432
444
argselect_64bit_<zmm_vector<T>>(
@@ -435,7 +447,7 @@ void avx512_argselect(T* arr, int64_t *arg, int64_t k, int64_t arrsize)
435
447
}
436
448
437
449
template <>
438
- void avx512_argselect (double * arr, int64_t *arg, int64_t k, int64_t arrsize)
450
+ void avx512_argselect (double * arr, int64_t *arg, int64_t k, int64_t arrsize)
439
451
{
440
452
if (arrsize > 1 ) {
441
453
if (has_nan<zmm_vector<double >>(arr, arrsize)) {
@@ -449,7 +461,7 @@ void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize)
449
461
}
450
462
451
463
template <>
452
- void avx512_argselect (int32_t * arr, int64_t *arg, int64_t k, int64_t arrsize)
464
+ void avx512_argselect (int32_t * arr, int64_t *arg, int64_t k, int64_t arrsize)
453
465
{
454
466
if (arrsize > 1 ) {
455
467
argselect_64bit_<ymm_vector<int32_t >>(
@@ -458,7 +470,7 @@ void avx512_argselect(int32_t* arr, int64_t *arg, int64_t k, int64_t arrsize)
458
470
}
459
471
460
472
template <>
461
- void avx512_argselect (uint32_t * arr, int64_t *arg, int64_t k, int64_t arrsize)
473
+ void avx512_argselect (uint32_t * arr, int64_t *arg, int64_t k, int64_t arrsize)
462
474
{
463
475
if (arrsize > 1 ) {
464
476
argselect_64bit_<ymm_vector<uint32_t >>(
@@ -467,7 +479,7 @@ void avx512_argselect(uint32_t* arr, int64_t *arg, int64_t k, int64_t arrsize)
467
479
}
468
480
469
481
template <>
470
- void avx512_argselect (float * arr, int64_t *arg, int64_t k, int64_t arrsize)
482
+ void avx512_argselect (float * arr, int64_t *arg, int64_t k, int64_t arrsize)
471
483
{
472
484
if (arrsize > 1 ) {
473
485
if (has_nan<ymm_vector<float >>(arr, arrsize)) {
@@ -481,7 +493,7 @@ void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize)
481
493
}
482
494
483
495
template <typename T>
484
- std::vector<int64_t > avx512_argselect (T* arr, int64_t k, int64_t arrsize)
496
+ std::vector<int64_t > avx512_argselect (T * arr, int64_t k, int64_t arrsize)
485
497
{
486
498
std::vector<int64_t > indices (arrsize);
487
499
std::iota (indices.begin (), indices.end (), 0 );
0 commit comments