Skip to content

Commit 3f958c7

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #145 from sterrettm2/kv-avx2
AVX2 Support for key-value sorting
2 parents aad3db1 + 3607752 commit 3f958c7

8 files changed

+321
-105
lines changed

lib/x86simdsort-avx2.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// AVX2 specific routines:
2+
23
#include "x86simdsort-static-incl.h"
34
#include "x86simdsort-internal.h"
45

@@ -33,6 +34,38 @@
3334
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
3435
}
3536

37+
#define DEFINE_KEYVALUE_METHODS(type) \
38+
template <> \
39+
void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \
40+
{ \
41+
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
42+
} \
43+
template <> \
44+
void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \
45+
{ \
46+
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
47+
} \
48+
template <> \
49+
void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \
50+
{ \
51+
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
52+
} \
53+
template <> \
54+
void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \
55+
{ \
56+
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
57+
} \
58+
template <> \
59+
void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \
60+
{ \
61+
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
62+
} \
63+
template <> \
64+
void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \
65+
{ \
66+
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
67+
}
68+
3669
namespace xss {
3770
namespace avx2 {
3871
DEFINE_ALL_METHODS(uint32_t)
@@ -41,5 +74,11 @@ namespace avx2 {
4174
DEFINE_ALL_METHODS(uint64_t)
4275
DEFINE_ALL_METHODS(int64_t)
4376
DEFINE_ALL_METHODS(double)
77+
DEFINE_KEYVALUE_METHODS(uint64_t)
78+
DEFINE_KEYVALUE_METHODS(int64_t)
79+
DEFINE_KEYVALUE_METHODS(double)
80+
DEFINE_KEYVALUE_METHODS(uint32_t)
81+
DEFINE_KEYVALUE_METHODS(int32_t)
82+
DEFINE_KEYVALUE_METHODS(float)
4483
} // namespace avx2
45-
} // namespace xss
84+
} // namespace xss

lib/x86simdsort-skx.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// SKX specific routines:
2+
23
#include "x86simdsort-static-incl.h"
34
#include "x86simdsort-internal.h"
45

lib/x86simdsort.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,12 @@ DISPATCH_ALL(argselect,
208208
(ISA_LIST("avx512_skx", "avx2")))
209209

210210
#define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \
211-
DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx"))) \
212-
DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx"))) \
213-
DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx"))) \
214-
DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx"))) \
215-
DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx"))) \
216-
DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx")))
211+
DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx", "avx2"))) \
212+
DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx", "avx2"))) \
213+
DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx", "avx2"))) \
214+
DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx", "avx2"))) \
215+
DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx", "avx2"))) \
216+
DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx", "avx2")))
217217

218218
DISPATCH_KEYVALUE_SORT_FORTYPE(uint64_t)
219219
DISPATCH_KEYVALUE_SORT_FORTYPE(int64_t)

src/avx2-32bit-qsort.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ struct avx2_vector<int32_t> {
9999
auto mask = ((0x1ull << num_to_read) - 0x1ull);
100100
return convert_int_to_avx2_mask(mask);
101101
}
102+
static opmask_t convert_int_to_mask(uint64_t intMask)
103+
{
104+
return convert_int_to_avx2_mask(intMask);
105+
}
102106
static ymmi_t
103107
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
104108
{
@@ -268,6 +272,10 @@ struct avx2_vector<uint32_t> {
268272
auto mask = ((0x1ull << num_to_read) - 0x1ull);
269273
return convert_int_to_avx2_mask(mask);
270274
}
275+
static opmask_t convert_int_to_mask(uint64_t intMask)
276+
{
277+
return convert_int_to_avx2_mask(intMask);
278+
}
271279
static ymmi_t
272280
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
273281
{
@@ -444,6 +452,10 @@ struct avx2_vector<float> {
444452
auto mask = ((0x1ull << num_to_read) - 0x1ull);
445453
return convert_int_to_avx2_mask(mask);
446454
}
455+
static opmask_t convert_int_to_mask(uint64_t intMask)
456+
{
457+
return convert_int_to_avx2_mask(intMask);
458+
}
447459
static int32_t convert_mask_to_int(opmask_t mask)
448460
{
449461
return convert_avx2_mask_to_int(mask);

src/avx512-64bit-common.h

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ template <typename vtype, typename reg_t>
2424
X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm);
2525

2626
struct avx512_64bit_swizzle_ops;
27+
struct avx512_ymm_64bit_swizzle_ops;
2728

2829
template <>
2930
struct ymm_vector<float> {
@@ -34,6 +35,8 @@ struct ymm_vector<float> {
3435
static const uint8_t numlanes = 8;
3536
static constexpr simd_type vec_type = simd_type::AVX512;
3637

38+
using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
39+
3740
static type_t type_max()
3841
{
3942
return X86_SIMD_SORT_INFINITYF;
@@ -212,6 +215,14 @@ struct ymm_vector<float> {
212215
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
213216
return permutexvar(rev_index, ymm);
214217
}
218+
static int double_compressstore(type_t *left_addr,
219+
type_t *right_addr,
220+
opmask_t k,
221+
reg_t reg)
222+
{
223+
return avx512_double_compressstore<ymm_vector<type_t>>(
224+
left_addr, right_addr, k, reg);
225+
}
215226
};
216227
template <>
217228
struct ymm_vector<uint32_t> {
@@ -222,6 +233,8 @@ struct ymm_vector<uint32_t> {
222233
static const uint8_t numlanes = 8;
223234
static constexpr simd_type vec_type = simd_type::AVX512;
224235

236+
using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
237+
225238
static type_t type_max()
226239
{
227240
return X86_SIMD_SORT_MAX_UINT32;
@@ -386,6 +399,14 @@ struct ymm_vector<uint32_t> {
386399
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
387400
return permutexvar(rev_index, ymm);
388401
}
402+
static int double_compressstore(type_t *left_addr,
403+
type_t *right_addr,
404+
opmask_t k,
405+
reg_t reg)
406+
{
407+
return avx512_double_compressstore<ymm_vector<type_t>>(
408+
left_addr, right_addr, k, reg);
409+
}
389410
};
390411
template <>
391412
struct ymm_vector<int32_t> {
@@ -396,6 +417,8 @@ struct ymm_vector<int32_t> {
396417
static const uint8_t numlanes = 8;
397418
static constexpr simd_type vec_type = simd_type::AVX512;
398419

420+
using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
421+
399422
static type_t type_max()
400423
{
401424
return X86_SIMD_SORT_MAX_INT32;
@@ -560,6 +583,14 @@ struct ymm_vector<int32_t> {
560583
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
561584
return permutexvar(rev_index, ymm);
562585
}
586+
static int double_compressstore(type_t *left_addr,
587+
type_t *right_addr,
588+
opmask_t k,
589+
reg_t reg)
590+
{
591+
return avx512_double_compressstore<ymm_vector<type_t>>(
592+
left_addr, right_addr, k, reg);
593+
}
563594
};
564595
template <>
565596
struct zmm_vector<int64_t> {
@@ -1215,4 +1246,77 @@ struct avx512_64bit_swizzle_ops {
12151246
}
12161247
};
12171248

1249+
struct avx512_ymm_64bit_swizzle_ops {
1250+
template <typename vtype, int scale>
1251+
X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg)
1252+
{
1253+
__m256i v = vtype::cast_to(reg);
1254+
1255+
if constexpr (scale == 2) {
1256+
__m256 vf = _mm256_castsi256_ps(v);
1257+
vf = _mm256_permute_ps(vf, 0b10110001);
1258+
v = _mm256_castps_si256(vf);
1259+
}
1260+
else if constexpr (scale == 4) {
1261+
__m256 vf = _mm256_castsi256_ps(v);
1262+
vf = _mm256_permute_ps(vf, 0b01001110);
1263+
v = _mm256_castps_si256(vf);
1264+
}
1265+
else if constexpr (scale == 8) {
1266+
v = _mm256_permute2x128_si256(v, v, 0b00000001);
1267+
}
1268+
else {
1269+
static_assert(scale == -1, "should not be reached");
1270+
}
1271+
1272+
return vtype::cast_from(v);
1273+
}
1274+
1275+
template <typename vtype, int scale>
1276+
X86_SIMD_SORT_INLINE typename vtype::reg_t
1277+
reverse_n(typename vtype::reg_t reg)
1278+
{
1279+
__m256i v = vtype::cast_to(reg);
1280+
1281+
if constexpr (scale == 2) { return swap_n<vtype, 2>(reg); }
1282+
else if constexpr (scale == 4) {
1283+
constexpr uint64_t mask = 0b00011011;
1284+
__m256 vf = _mm256_castsi256_ps(v);
1285+
vf = _mm256_permute_ps(vf, mask);
1286+
v = _mm256_castps_si256(vf);
1287+
}
1288+
else if constexpr (scale == 8) {
1289+
return vtype::reverse(reg);
1290+
}
1291+
else {
1292+
static_assert(scale == -1, "should not be reached");
1293+
}
1294+
1295+
return vtype::cast_from(v);
1296+
}
1297+
1298+
template <typename vtype, int scale>
1299+
X86_SIMD_SORT_INLINE typename vtype::reg_t
1300+
merge_n(typename vtype::reg_t reg, typename vtype::reg_t other)
1301+
{
1302+
__m256i v1 = vtype::cast_to(reg);
1303+
__m256i v2 = vtype::cast_to(other);
1304+
1305+
if constexpr (scale == 2) {
1306+
v1 = _mm256_blend_epi32(v1, v2, 0b01010101);
1307+
}
1308+
else if constexpr (scale == 4) {
1309+
v1 = _mm256_blend_epi32(v1, v2, 0b00110011);
1310+
}
1311+
else if constexpr (scale == 8) {
1312+
v1 = _mm256_blend_epi32(v1, v2, 0b00001111);
1313+
}
1314+
else {
1315+
static_assert(scale == -1, "should not be reached");
1316+
}
1317+
1318+
return vtype::cast_from(v1);
1319+
}
1320+
};
1321+
12181322
#endif

src/x86simdsort-static-incl.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,26 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false);
100100
std::iota(indices.begin(), indices.end(), 0); \
101101
x86simdsortStatic::argselect(arr, indices.data(), k, size, hasnan); \
102102
return indices; \
103+
} \
104+
template <typename T1, typename T2> \
105+
X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_qsort( \
106+
T1 *key, T2 *val, size_t size, bool hasnan) \
107+
{ \
108+
ISA##_qsort_kv(key, val, size, hasnan); \
103109
}
104110

105111
/*
106112
* qsort, qselect, partial, argsort key-value sort template functions.
107113
*/
108114
#include "xss-common-qsort.h"
109115
#include "xss-common-argsort.h"
116+
#include "xss-common-keyvaluesort.hpp"
110117

111118
#if defined(__AVX512DQ__) && defined(__AVX512VL__)
112119
/* 32-bit and 64-bit dtypes vector definitions on SKX */
113120
#include "avx512-32bit-qsort.hpp"
114121
#include "avx512-64bit-qsort.hpp"
115122
#include "avx512-64bit-argsort.hpp"
116-
#include "avx512-64bit-keyvaluesort.hpp"
117123

118124
/* 16-bit dtypes vector definitions on ICL */
119125
#if defined(__AVX512BW__) && defined(__AVX512VBMI2__)
@@ -126,14 +132,6 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false);
126132

127133
XSS_METHODS(avx512)
128134

129-
// key-value currently only on avx512
130-
template <typename T1, typename T2>
131-
X86_SIMD_SORT_FINLINE void
132-
x86simdsortStatic::keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan)
133-
{
134-
avx512_qsort_kv(key, val, size, hasnan);
135-
}
136-
137135
#elif defined(__AVX512F__)
138136
#error "x86simdsort requires AVX512DQ and AVX512VL to be enabled in addition to AVX512F to use AVX512"
139137

0 commit comments

Comments
 (0)