Skip to content

Commit 10dc3dc

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #88 from r-devulap/arg-32bit
Enable argsort and argselect methods on 32-bit platforms
2 parents 59c241b + 8992eaa commit 10dc3dc

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

src/avx512-64bit-common.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ struct ymm_vector<float> {
9595
{
9696
return _mm512_mask_i64gather_ps(src, mask, index, base, scale);
9797
}
98+
/*
 * Masked gather overload that accepts its indices in a 256-bit register
 * (__m256i) and forwards to _mm256_mmask_i32gather_ps, the gather variant
 * that reads 32-bit indices. The function keeps the `mask_i64gather` name,
 * so it is selected by overload resolution on the type of `index`.
 * `scale` is a template parameter and is therefore passed to the intrinsic
 * as a compile-time constant.
 * NOTE(review): presumably exists so callers using 32-bit arrsize_t indices
 * can share the same call site as the 64-bit index path -- confirm.
 */
template <int scale>
99+
static reg_t
100+
mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base)
101+
{
102+
return _mm256_mmask_i32gather_ps(src, mask, index, base, scale);
103+
}
98104
static reg_t i64gather(type_t *arr, arrsize_t *ind)
99105
{
100106
return set(arr[ind[7]],
@@ -247,6 +253,12 @@ struct ymm_vector<uint32_t> {
247253
{
248254
return _mm512_mask_i64gather_epi32(src, mask, index, base, scale);
249255
}
256+
/*
 * Masked gather overload that accepts its indices in a 256-bit register
 * (__m256i) and forwards to _mm256_mmask_i32gather_epi32, the gather variant
 * that reads 32-bit indices. The function keeps the `mask_i64gather` name,
 * so it is selected by overload resolution on the type of `index`.
 * `scale` is a template parameter and is therefore passed to the intrinsic
 * as a compile-time constant.
 * NOTE(review): presumably exists so callers using 32-bit arrsize_t indices
 * can share the same call site as the 64-bit index path -- confirm.
 */
template <int scale>
257+
static reg_t
258+
mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base)
259+
{
260+
return _mm256_mmask_i32gather_epi32(src, mask, index, base, scale);
261+
}
250262
static reg_t i64gather(type_t *arr, arrsize_t *ind)
251263
{
252264
return set(arr[ind[7]],
@@ -393,6 +405,12 @@ struct ymm_vector<int32_t> {
393405
{
394406
return _mm512_mask_i64gather_epi32(src, mask, index, base, scale);
395407
}
408+
/*
 * Masked gather overload that accepts its indices in a 256-bit register
 * (__m256i) and forwards to _mm256_mmask_i32gather_epi32, the gather variant
 * that reads 32-bit indices. The function keeps the `mask_i64gather` name,
 * so it is selected by overload resolution on the type of `index`.
 * `scale` is a template parameter and is therefore passed to the intrinsic
 * as a compile-time constant.
 * NOTE(review): presumably exists so callers using 32-bit arrsize_t indices
 * can share the same call site as the 64-bit index path -- confirm.
 */
template <int scale>
409+
static reg_t
410+
mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base)
411+
{
412+
return _mm256_mmask_i32gather_epi32(src, mask, index, base, scale);
413+
}
396414
static reg_t i64gather(type_t *arr, arrsize_t *ind)
397415
{
398416
return set(arr[ind[7]],
@@ -548,6 +566,12 @@ struct zmm_vector<int64_t> {
548566
{
549567
return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
550568
}
569+
/*
 * Masked gather overload that accepts its indices in a 256-bit register
 * (__m256i) and forwards to _mm512_mask_i32gather_epi64, the 512-bit gather
 * whose index operand holds 32-bit indices. The function keeps the
 * `mask_i64gather` name, so it is selected by overload resolution on the
 * type of `index`.
 * `scale` is a template parameter and is therefore passed to the intrinsic
 * as a compile-time constant.
 * NOTE(review): presumably exists so callers using 32-bit arrsize_t indices
 * can share the same call site as the 64-bit index path -- confirm.
 */
template <int scale>
570+
static reg_t
571+
mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base)
572+
{
573+
return _mm512_mask_i32gather_epi64(src, mask, index, base, scale);
574+
}
551575
static reg_t i64gather(type_t *arr, arrsize_t *ind)
552576
{
553577
return set(arr[ind[7]],
@@ -688,6 +712,12 @@ struct zmm_vector<uint64_t> {
688712
{
689713
return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
690714
}
715+
/*
 * Masked gather overload that accepts its indices in a 256-bit register
 * (__m256i) and forwards to _mm512_mask_i32gather_epi64, the 512-bit gather
 * whose index operand holds 32-bit indices. The function keeps the
 * `mask_i64gather` name, so it is selected by overload resolution on the
 * type of `index`.
 * `scale` is a template parameter and is therefore passed to the intrinsic
 * as a compile-time constant.
 * NOTE(review): presumably exists so callers using 32-bit arrsize_t indices
 * can share the same call site as the 64-bit index path -- confirm.
 */
template <int scale>
716+
static reg_t
717+
mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base)
718+
{
719+
return _mm512_mask_i32gather_epi64(src, mask, index, base, scale);
720+
}
691721
static reg_t i64gather(type_t *arr, arrsize_t *ind)
692722
{
693723
return set(arr[ind[7]],
@@ -864,6 +894,12 @@ struct zmm_vector<double> {
864894
{
865895
return _mm512_mask_i64gather_pd(src, mask, index, base, scale);
866896
}
897+
/*
 * Masked gather overload that accepts its indices in a 256-bit register
 * (__m256i) and forwards to _mm512_mask_i32gather_pd, the 512-bit gather
 * whose index operand holds 32-bit indices. The function keeps the
 * `mask_i64gather` name, so it is selected by overload resolution on the
 * type of `index`.
 * `scale` is a template parameter and is therefore passed to the intrinsic
 * as a compile-time constant.
 * NOTE(review): presumably exists so callers using 32-bit arrsize_t indices
 * can share the same call site as the 64-bit index path -- confirm.
 */
template <int scale>
898+
static reg_t
899+
mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base)
900+
{
901+
return _mm512_mask_i32gather_pd(src, mask, index, base, scale);
902+
}
867903
static reg_t i64gather(type_t *arr, arrsize_t *ind)
868904
{
869905
return set(arr[ind[7]],

src/avx512-common-argsort.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
#include <stdio.h>
1313
#include <vector>
1414

15-
using argtype = zmm_vector<arrsize_t>;
15+
/*
 * Vector wrapper used for the index ("arg") array. Chosen at compile time
 * from the width of arrsize_t: when arrsize_t has the same size as int32_t
 * the alias resolves to ymm_vector<arrsize_t>, otherwise to
 * zmm_vector<arrsize_t>.
 */
using argtype = typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
16+
ymm_vector<arrsize_t>,
17+
zmm_vector<arrsize_t>>::type;
1618
using argreg_t = typename argtype::reg_t;
1719

1820
/*

tests/test-qsort.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort)
103103
for (auto type : this->arrtype) {
104104
for (auto size : this->arrsize) {
105105
// k should be at least 1
106-
size_t k = std::max(0x1ul, rand() % size);
106+
size_t k = std::max((size_t)1, rand() % size);
107107
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
108108
std::vector<TypeParam> sortedarr = arr;
109109
std::sort(sortedarr.begin(),

0 commit comments

Comments
 (0)