Skip to content

Commit d424241

Browse files
author
Raghuveer Devulapalli
committed
clang format
1 parent 624df3a commit d424241

12 files changed

+212
-111
lines changed

src/avx512-16bit-qsort.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,11 +433,11 @@ void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
433433
{
434434
int64_t indx_last_elem = arrsize - 1;
435435
if (UNLIKELY(hasnan)) {
436-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
436+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
437437
}
438438
if (indx_last_elem >= k) {
439439
qselect_16bit_<zmm_vector<float16>, uint16_t>(
440-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
440+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
441441
}
442442
}
443443

src/avx512-32bit-qsort.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,10 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
715715
}
716716

717717
template <>
718-
void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
718+
void avx512_qselect<int32_t>(int32_t *arr,
719+
int64_t k,
720+
int64_t arrsize,
721+
bool hasnan)
719722
{
720723
if (arrsize > 1) {
721724
qselect_32bit_<zmm_vector<int32_t>, int32_t>(
@@ -724,7 +727,10 @@ void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize, bool hasn
724727
}
725728

726729
template <>
727-
void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
730+
void avx512_qselect<uint32_t>(uint32_t *arr,
731+
int64_t k,
732+
int64_t arrsize,
733+
bool hasnan)
728734
{
729735
if (arrsize > 1) {
730736
qselect_32bit_<zmm_vector<uint32_t>, uint32_t>(
@@ -737,11 +743,11 @@ void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
737743
{
738744
int64_t indx_last_elem = arrsize - 1;
739745
if (UNLIKELY(hasnan)) {
740-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
746+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
741747
}
742748
if (indx_last_elem >= k) {
743749
qselect_32bit_<zmm_vector<float>, float>(
744-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
750+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
745751
}
746752
}
747753

src/avx512-64bit-argsort.hpp

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,33 +8,45 @@
88
#define AVX512_ARGSORT_64BIT
99

1010
#include "avx512-64bit-common.h"
11-
#include "avx512-common-argsort.h"
1211
#include "avx512-64bit-keyvalue-networks.hpp"
12+
#include "avx512-common-argsort.h"
1313

1414
template <typename T>
15-
void std_argselect_withnan(T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right)
15+
void std_argselect_withnan(
16+
T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right)
1617
{
1718
std::nth_element(arg + left,
1819
arg + k,
1920
arg + right,
2021
[arr](int64_t a, int64_t b) -> bool {
21-
if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] < arr[b];}
22-
else if (std::isnan(arr[a])) {return false;}
23-
else {return true;}
22+
if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {
23+
return arr[a] < arr[b];
24+
}
25+
else if (std::isnan(arr[a])) {
26+
return false;
27+
}
28+
else {
29+
return true;
30+
}
2431
});
2532
}
2633

27-
2834
/* argsort using std::sort */
2935
template <typename T>
3036
void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right)
3137
{
3238
std::sort(arg + left,
3339
arg + right,
3440
[arr](int64_t left, int64_t right) -> bool {
35-
if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {return arr[left] < arr[right];}
36-
else if (std::isnan(arr[left])) {return false;}
37-
else {return true;}
41+
if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {
42+
return arr[left] < arr[right];
43+
}
44+
else if (std::isnan(arr[left])) {
45+
return false;
46+
}
47+
else {
48+
return true;
49+
}
3850
});
3951
}
4052

@@ -325,13 +337,15 @@ static void argselect_64bit_(type_t *arr,
325337
int64_t pivot_index = partition_avx512_unrolled<vtype, 4>(
326338
arr, arg, left, right + 1, pivot, &smallest, &biggest);
327339
if ((pivot != smallest) && (pos < pivot_index))
328-
argselect_64bit_<vtype>(arr, arg, pos, left, pivot_index - 1, max_iters - 1);
340+
argselect_64bit_<vtype>(
341+
arr, arg, pos, left, pivot_index - 1, max_iters - 1);
329342
else if ((pivot != biggest) && (pos >= pivot_index))
330-
argselect_64bit_<vtype>(arr, arg, pos, pivot_index, right, max_iters - 1);
343+
argselect_64bit_<vtype>(
344+
arr, arg, pos, pivot_index, right, max_iters - 1);
331345
}
332346

333347
template <typename vtype, typename type_t>
334-
bool has_nan(type_t* arr, int64_t arrsize)
348+
bool has_nan(type_t *arr, int64_t arrsize)
335349
{
336350
using opmask_t = typename vtype::opmask_t;
337351
using zmm_t = typename vtype::zmm_t;
@@ -346,7 +360,7 @@ bool has_nan(type_t* arr, int64_t arrsize)
346360
else {
347361
in = vtype::loadu(arr);
348362
}
349-
opmask_t nanmask = vtype::template fpclass<0x01|0x80>(in);
363+
opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in);
350364
arr += vtype::numlanes;
351365
arrsize -= vtype::numlanes;
352366
if (nanmask != 0x00) {
@@ -357,10 +371,9 @@ bool has_nan(type_t* arr, int64_t arrsize)
357371
return found_nan;
358372
}
359373

360-
361374
/* argsort methods for 32-bit and 64-bit dtypes */
362375
template <typename T>
363-
void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
376+
void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize)
364377
{
365378
if (arrsize > 1) {
366379
argsort_64bit_<zmm_vector<T>>(
@@ -369,7 +382,7 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
369382
}
370383

371384
template <>
372-
void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize)
385+
void avx512_argsort(double *arr, int64_t *arg, int64_t arrsize)
373386
{
374387
if (arrsize > 1) {
375388
if (has_nan<zmm_vector<double>>(arr, arrsize)) {
@@ -382,9 +395,8 @@ void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize)
382395
}
383396
}
384397

385-
386398
template <>
387-
void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize)
399+
void avx512_argsort(int32_t *arr, int64_t *arg, int64_t arrsize)
388400
{
389401
if (arrsize > 1) {
390402
argsort_64bit_<ymm_vector<int32_t>>(
@@ -393,7 +405,7 @@ void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize)
393405
}
394406

395407
template <>
396-
void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize)
408+
void avx512_argsort(uint32_t *arr, int64_t *arg, int64_t arrsize)
397409
{
398410
if (arrsize > 1) {
399411
argsort_64bit_<ymm_vector<uint32_t>>(
@@ -402,7 +414,7 @@ void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize)
402414
}
403415

404416
template <>
405-
void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize)
417+
void avx512_argsort(float *arr, int64_t *arg, int64_t arrsize)
406418
{
407419
if (arrsize > 1) {
408420
if (has_nan<ymm_vector<float>>(arr, arrsize)) {
@@ -416,7 +428,7 @@ void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize)
416428
}
417429

418430
template <typename T>
419-
std::vector<int64_t> avx512_argsort(T* arr, int64_t arrsize)
431+
std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
420432
{
421433
std::vector<int64_t> indices(arrsize);
422434
std::iota(indices.begin(), indices.end(), 0);
@@ -426,7 +438,7 @@ std::vector<int64_t> avx512_argsort(T* arr, int64_t arrsize)
426438

427439
/* argselect methods for 32-bit and 64-bit dtypes */
428440
template <typename T>
429-
void avx512_argselect(T* arr, int64_t *arg, int64_t k, int64_t arrsize)
441+
void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize)
430442
{
431443
if (arrsize > 1) {
432444
argselect_64bit_<zmm_vector<T>>(
@@ -435,7 +447,7 @@ void avx512_argselect(T* arr, int64_t *arg, int64_t k, int64_t arrsize)
435447
}
436448

437449
template <>
438-
void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize)
450+
void avx512_argselect(double *arr, int64_t *arg, int64_t k, int64_t arrsize)
439451
{
440452
if (arrsize > 1) {
441453
if (has_nan<zmm_vector<double>>(arr, arrsize)) {
@@ -449,7 +461,7 @@ void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize)
449461
}
450462

451463
template <>
452-
void avx512_argselect(int32_t* arr, int64_t *arg, int64_t k, int64_t arrsize)
464+
void avx512_argselect(int32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
453465
{
454466
if (arrsize > 1) {
455467
argselect_64bit_<ymm_vector<int32_t>>(
@@ -458,7 +470,7 @@ void avx512_argselect(int32_t* arr, int64_t *arg, int64_t k, int64_t arrsize)
458470
}
459471

460472
template <>
461-
void avx512_argselect(uint32_t* arr, int64_t *arg, int64_t k, int64_t arrsize)
473+
void avx512_argselect(uint32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
462474
{
463475
if (arrsize > 1) {
464476
argselect_64bit_<ymm_vector<uint32_t>>(
@@ -467,7 +479,7 @@ void avx512_argselect(uint32_t* arr, int64_t *arg, int64_t k, int64_t arrsize)
467479
}
468480

469481
template <>
470-
void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize)
482+
void avx512_argselect(float *arr, int64_t *arg, int64_t k, int64_t arrsize)
471483
{
472484
if (arrsize > 1) {
473485
if (has_nan<ymm_vector<float>>(arr, arrsize)) {
@@ -481,7 +493,7 @@ void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize)
481493
}
482494

483495
template <typename T>
484-
std::vector<int64_t> avx512_argselect(T* arr, int64_t k, int64_t arrsize)
496+
std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
485497
{
486498
std::vector<int64_t> indices(arrsize);
487499
std::iota(indices.begin(), indices.end(), 0);

src/avx512-64bit-keyvalue-networks.hpp

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,
136136
typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]);
137137
typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]);
138138

139-
index_type index_zmm_t1 = vtype2::mask_mov(
140-
index_zmm3r, movmask1, index_zmm[0]);
141-
index_type index_zmm_m1 = vtype2::mask_mov(
142-
index_zmm[0], movmask1, index_zmm3r);
143-
index_type index_zmm_t2 = vtype2::mask_mov(
144-
index_zmm2r, movmask2, index_zmm[1]);
145-
index_type index_zmm_m2 = vtype2::mask_mov(
146-
index_zmm[1], movmask2, index_zmm2r);
139+
index_type index_zmm_t1
140+
= vtype2::mask_mov(index_zmm3r, movmask1, index_zmm[0]);
141+
index_type index_zmm_m1
142+
= vtype2::mask_mov(index_zmm[0], movmask1, index_zmm3r);
143+
index_type index_zmm_t2
144+
= vtype2::mask_mov(index_zmm2r, movmask2, index_zmm[1]);
145+
index_type index_zmm_m2
146+
= vtype2::mask_mov(index_zmm[1], movmask2, index_zmm2r);
147147

148148
// 2) Recursive half clearer: 16
149149
zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2);
@@ -159,14 +159,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,
159159
movmask1 = vtype1::eq(key_zmm0, key_zmm_t1);
160160
movmask2 = vtype1::eq(key_zmm2, key_zmm_t3);
161161

162-
index_type index_zmm0 = vtype2::mask_mov(
163-
index_zmm_t2, movmask1, index_zmm_t1);
164-
index_type index_zmm1 = vtype2::mask_mov(
165-
index_zmm_t1, movmask1, index_zmm_t2);
166-
index_type index_zmm2 = vtype2::mask_mov(
167-
index_zmm_t4, movmask2, index_zmm_t3);
168-
index_type index_zmm3 = vtype2::mask_mov(
169-
index_zmm_t3, movmask2, index_zmm_t4);
162+
index_type index_zmm0
163+
= vtype2::mask_mov(index_zmm_t2, movmask1, index_zmm_t1);
164+
index_type index_zmm1
165+
= vtype2::mask_mov(index_zmm_t1, movmask1, index_zmm_t2);
166+
index_type index_zmm2
167+
= vtype2::mask_mov(index_zmm_t4, movmask2, index_zmm_t3);
168+
index_type index_zmm3
169+
= vtype2::mask_mov(index_zmm_t3, movmask2, index_zmm_t4);
170170

171171
key_zmm[0] = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm0, index_zmm0);
172172
key_zmm[1] = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm1, index_zmm1);
@@ -212,22 +212,22 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm,
212212
typename vtype1::opmask_t movmask3 = vtype1::eq(key_zmm_t3, key_zmm[2]);
213213
typename vtype1::opmask_t movmask4 = vtype1::eq(key_zmm_t4, key_zmm[3]);
214214

215-
index_type index_zmm_t1 = vtype2::mask_mov(
216-
index_zmm7r, movmask1, index_zmm[0]);
217-
index_type index_zmm_m1 = vtype2::mask_mov(
218-
index_zmm[0], movmask1, index_zmm7r);
219-
index_type index_zmm_t2 = vtype2::mask_mov(
220-
index_zmm6r, movmask2, index_zmm[1]);
221-
index_type index_zmm_m2 = vtype2::mask_mov(
222-
index_zmm[1], movmask2, index_zmm6r);
223-
index_type index_zmm_t3 = vtype2::mask_mov(
224-
index_zmm5r, movmask3, index_zmm[2]);
225-
index_type index_zmm_m3 = vtype2::mask_mov(
226-
index_zmm[2], movmask3, index_zmm5r);
227-
index_type index_zmm_t4 = vtype2::mask_mov(
228-
index_zmm4r, movmask4, index_zmm[3]);
229-
index_type index_zmm_m4 = vtype2::mask_mov(
230-
index_zmm[3], movmask4, index_zmm4r);
215+
index_type index_zmm_t1
216+
= vtype2::mask_mov(index_zmm7r, movmask1, index_zmm[0]);
217+
index_type index_zmm_m1
218+
= vtype2::mask_mov(index_zmm[0], movmask1, index_zmm7r);
219+
index_type index_zmm_t2
220+
= vtype2::mask_mov(index_zmm6r, movmask2, index_zmm[1]);
221+
index_type index_zmm_m2
222+
= vtype2::mask_mov(index_zmm[1], movmask2, index_zmm6r);
223+
index_type index_zmm_t3
224+
= vtype2::mask_mov(index_zmm5r, movmask3, index_zmm[2]);
225+
index_type index_zmm_m3
226+
= vtype2::mask_mov(index_zmm[2], movmask3, index_zmm5r);
227+
index_type index_zmm_t4
228+
= vtype2::mask_mov(index_zmm4r, movmask4, index_zmm[3]);
229+
index_type index_zmm_m4
230+
= vtype2::mask_mov(index_zmm[3], movmask4, index_zmm4r);
231231

232232
zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4);
233233
zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3);

src/avx512-64bit-qsort.hpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,10 @@ static void qselect_64bit_(type_t *arr,
784784
}
785785

786786
template <>
787-
void avx512_qselect<int64_t>(int64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
787+
void avx512_qselect<int64_t>(int64_t *arr,
788+
int64_t k,
789+
int64_t arrsize,
790+
bool hasnan)
788791
{
789792
if (arrsize > 1) {
790793
qselect_64bit_<zmm_vector<int64_t>, int64_t>(
@@ -793,7 +796,10 @@ void avx512_qselect<int64_t>(int64_t *arr, int64_t k, int64_t arrsize, bool hasn
793796
}
794797

795798
template <>
796-
void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
799+
void avx512_qselect<uint64_t>(uint64_t *arr,
800+
int64_t k,
801+
int64_t arrsize,
802+
bool hasnan)
797803
{
798804
if (arrsize > 1) {
799805
qselect_64bit_<zmm_vector<uint64_t>, uint64_t>(
@@ -802,15 +808,18 @@ void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize, bool ha
802808
}
803809

804810
template <>
805-
void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize, bool hasnan)
811+
void avx512_qselect<double>(double *arr,
812+
int64_t k,
813+
int64_t arrsize,
814+
bool hasnan)
806815
{
807816
int64_t indx_last_elem = arrsize - 1;
808817
if (UNLIKELY(hasnan)) {
809-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
818+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
810819
}
811820
if (indx_last_elem >= k) {
812821
qselect_64bit_<zmm_vector<double>, double>(
813-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
822+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
814823
}
815824
}
816825

src/avx512fp16-16bit-qsort.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,11 @@ void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
157157
{
158158
int64_t indx_last_elem = arrsize - 1;
159159
if (UNLIKELY(hasnan)) {
160-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
160+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
161161
}
162162
if (indx_last_elem >= k) {
163163
qselect_16bit_<zmm_vector<_Float16>, _Float16>(
164-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
164+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
165165
}
166166
}
167167

0 commit comments

Comments
 (0)