Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion lib/x86simdsort-avx2.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// AVX2 specific routines:

#include "x86simdsort-static-incl.h"
#include "x86simdsort-internal.h"

Expand Down Expand Up @@ -33,6 +34,38 @@
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
}

/*
 * Emits one explicit specialization of keyvalue_qsort for the pair
 * (key_type, val_type). The specialization passes its four arguments,
 * unmodified, to x86simdsortStatic::keyvalue_qsort.
 */
#define DEFINE_KEYVALUE_METHODS_BASE(key_type, val_type) \
    template <> \
    void keyvalue_qsort( \
            key_type *key, val_type *val, size_t arrsize, bool hasnan) \
    { \
        x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
    }

/*
 * Emits the keyvalue_qsort specializations for keys of `type` paired with
 * each value type, in this order: uint64_t, int64_t, double, uint32_t,
 * int32_t, float.
 */
#define DEFINE_KEYVALUE_METHODS(type) \
    DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \
    DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \
    DEFINE_KEYVALUE_METHODS_BASE(type, double) \
    DEFINE_KEYVALUE_METHODS_BASE(type, uint32_t) \
    DEFINE_KEYVALUE_METHODS_BASE(type, int32_t) \
    DEFINE_KEYVALUE_METHODS_BASE(type, float)

namespace xss {
namespace avx2 {
DEFINE_ALL_METHODS(uint32_t)
Expand All @@ -41,5 +74,11 @@ namespace avx2 {
DEFINE_ALL_METHODS(uint64_t)
DEFINE_ALL_METHODS(int64_t)
DEFINE_ALL_METHODS(double)
DEFINE_KEYVALUE_METHODS(uint64_t)
DEFINE_KEYVALUE_METHODS(int64_t)
DEFINE_KEYVALUE_METHODS(double)
DEFINE_KEYVALUE_METHODS(uint32_t)
DEFINE_KEYVALUE_METHODS(int32_t)
DEFINE_KEYVALUE_METHODS(float)
} // namespace avx2
} // namespace xss
} // namespace xss
1 change: 1 addition & 0 deletions lib/x86simdsort-skx.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// SKX specific routines:

#include "x86simdsort-static-incl.h"
#include "x86simdsort-internal.h"

Expand Down
12 changes: 6 additions & 6 deletions lib/x86simdsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,12 @@ DISPATCH_ALL(argselect,
(ISA_LIST("avx512_skx", "avx2")))

#define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \
DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx"))) \
DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx"))) \
DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx"))) \
DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx"))) \
DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx"))) \
DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx", "avx2"))) \
DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx", "avx2"))) \
DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx", "avx2"))) \
DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx", "avx2"))) \
DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx", "avx2"))) \
DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx", "avx2")))

DISPATCH_KEYVALUE_SORT_FORTYPE(uint64_t)
DISPATCH_KEYVALUE_SORT_FORTYPE(int64_t)
Expand Down
12 changes: 12 additions & 0 deletions src/avx2-32bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ struct avx2_vector<int32_t> {
auto mask = ((0x1ull << num_to_read) - 0x1ull);
return convert_int_to_avx2_mask(mask);
}
// Produces this vector type's opmask_t from an integer bitmask by passing
// intMask unchanged to convert_int_to_avx2_mask.
// NOTE(review): presumably bit i of intMask controls lane i -- confirm
// against convert_int_to_avx2_mask.
static opmask_t convert_int_to_mask(uint64_t intMask)
{
return convert_int_to_avx2_mask(intMask);
}
static ymmi_t
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
{
Expand Down Expand Up @@ -268,6 +272,10 @@ struct avx2_vector<uint32_t> {
auto mask = ((0x1ull << num_to_read) - 0x1ull);
return convert_int_to_avx2_mask(mask);
}
// Produces this vector type's opmask_t from an integer bitmask by passing
// intMask unchanged to convert_int_to_avx2_mask.
// NOTE(review): presumably bit i of intMask controls lane i -- confirm
// against convert_int_to_avx2_mask.
static opmask_t convert_int_to_mask(uint64_t intMask)
{
return convert_int_to_avx2_mask(intMask);
}
static ymmi_t
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
{
Expand Down Expand Up @@ -444,6 +452,10 @@ struct avx2_vector<float> {
auto mask = ((0x1ull << num_to_read) - 0x1ull);
return convert_int_to_avx2_mask(mask);
}
// Produces this vector type's opmask_t from an integer bitmask by passing
// intMask unchanged to convert_int_to_avx2_mask.
// NOTE(review): presumably bit i of intMask controls lane i -- confirm
// against convert_int_to_avx2_mask.
static opmask_t convert_int_to_mask(uint64_t intMask)
{
return convert_int_to_avx2_mask(intMask);
}
static int32_t convert_mask_to_int(opmask_t mask)
{
return convert_avx2_mask_to_int(mask);
Expand Down
104 changes: 104 additions & 0 deletions src/avx512-64bit-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ template <typename vtype, typename reg_t>
X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm);

struct avx512_64bit_swizzle_ops;
struct avx512_ymm_64bit_swizzle_ops;

template <>
struct ymm_vector<float> {
Expand All @@ -34,6 +35,8 @@ struct ymm_vector<float> {
static const uint8_t numlanes = 8;
static constexpr simd_type vec_type = simd_type::AVX512;

using swizzle_ops = avx512_ymm_64bit_swizzle_ops;

static type_t type_max()
{
return X86_SIMD_SORT_INFINITYF;
Expand Down Expand Up @@ -212,6 +215,14 @@ struct ymm_vector<float> {
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
return permutexvar(rev_index, ymm);
}
// Forwards to avx512_double_compressstore, instantiated for
// ymm_vector<type_t>. Both destination pointers, the mask k and the register
// are passed through unchanged, and the helper's int result is returned as-is.
// NOTE(review): presumably the lanes of reg are split between left_addr and
// right_addr according to k and the return value is a lane count -- confirm
// against avx512_double_compressstore.
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
reg_t reg)
{
return avx512_double_compressstore<ymm_vector<type_t>>(
left_addr, right_addr, k, reg);
}
};
template <>
struct ymm_vector<uint32_t> {
Expand All @@ -222,6 +233,8 @@ struct ymm_vector<uint32_t> {
static const uint8_t numlanes = 8;
static constexpr simd_type vec_type = simd_type::AVX512;

using swizzle_ops = avx512_ymm_64bit_swizzle_ops;

static type_t type_max()
{
return X86_SIMD_SORT_MAX_UINT32;
Expand Down Expand Up @@ -386,6 +399,14 @@ struct ymm_vector<uint32_t> {
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
return permutexvar(rev_index, ymm);
}
// Forwards to avx512_double_compressstore, instantiated for
// ymm_vector<type_t>. Both destination pointers, the mask k and the register
// are passed through unchanged, and the helper's int result is returned as-is.
// NOTE(review): presumably the lanes of reg are split between left_addr and
// right_addr according to k and the return value is a lane count -- confirm
// against avx512_double_compressstore.
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
reg_t reg)
{
return avx512_double_compressstore<ymm_vector<type_t>>(
left_addr, right_addr, k, reg);
}
};
template <>
struct ymm_vector<int32_t> {
Expand All @@ -396,6 +417,8 @@ struct ymm_vector<int32_t> {
static const uint8_t numlanes = 8;
static constexpr simd_type vec_type = simd_type::AVX512;

using swizzle_ops = avx512_ymm_64bit_swizzle_ops;

static type_t type_max()
{
return X86_SIMD_SORT_MAX_INT32;
Expand Down Expand Up @@ -560,6 +583,14 @@ struct ymm_vector<int32_t> {
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
return permutexvar(rev_index, ymm);
}
// Forwards to avx512_double_compressstore, instantiated for
// ymm_vector<type_t>. Both destination pointers, the mask k and the register
// are passed through unchanged, and the helper's int result is returned as-is.
// NOTE(review): presumably the lanes of reg are split between left_addr and
// right_addr according to k and the return value is a lane count -- confirm
// against avx512_double_compressstore.
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
reg_t reg)
{
return avx512_double_compressstore<ymm_vector<type_t>>(
left_addr, right_addr, k, reg);
}
};
template <>
struct zmm_vector<int64_t> {
Expand Down Expand Up @@ -1215,4 +1246,77 @@ struct avx512_64bit_swizzle_ops {
}
};

/*
 * Swizzle operations on a 256-bit register treated as eight 32-bit lanes
 * (lane 0 is the lowest 32 bits). In this header, ymm_vector<float>,
 * ymm_vector<uint32_t> and ymm_vector<int32_t> name this struct as their
 * swizzle_ops.
 *
 * Each operation moves its input into the integer domain with
 * vtype::cast_to, rearranges lanes, and hands the result back through
 * vtype::cast_from. The supported `scale` values are 2, 4 and 8; any other
 * value lands in the final branch and is checked by its static_assert.
 */
struct avx512_ymm_64bit_swizzle_ops {
    /*
     * Runs _mm256_permute_ps with the compile-time immediate imm8 on an
     * integer register. The same 4-lane pattern is applied to each 128-bit
     * half; the float casts are only there to satisfy the intrinsic's type.
     */
    template <int imm8>
    X86_SIMD_SORT_INLINE __m256i permute_ps_si256(__m256i lanes)
    {
        return _mm256_castps_si256(
                _mm256_permute_ps(_mm256_castsi256_ps(lanes), imm8));
    }

    /*
     * Exchanges the two halves of every group of `scale` lanes:
     *   scale == 2: lanes of each adjacent pair trade places,
     *   scale == 4: the two lane-pairs of each 128-bit half trade places,
     *   scale == 8: the lower and upper 128-bit halves trade places.
     */
    template <typename vtype, int scale>
    X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg)
    {
        __m256i lanes = vtype::cast_to(reg);

        if constexpr (scale == 8) {
            lanes = _mm256_permute2x128_si256(lanes, lanes, 0b00000001);
        }
        else if constexpr (scale == 4) {
            lanes = permute_ps_si256<0b01001110>(lanes);
        }
        else if constexpr (scale == 2) {
            lanes = permute_ps_si256<0b10110001>(lanes);
        }
        else {
            static_assert(scale == -1, "should not be reached");
        }

        return vtype::cast_from(lanes);
    }

    /*
     * Reverses the lane order inside every group of `scale` lanes:
     *   scale == 2: identical to swap_n<vtype, 2>,
     *   scale == 4: lanes of each 128-bit half are reversed,
     *   scale == 8: delegated to vtype::reverse.
     */
    template <typename vtype, int scale>
    X86_SIMD_SORT_INLINE typename vtype::reg_t
    reverse_n(typename vtype::reg_t reg)
    {
        __m256i lanes = vtype::cast_to(reg);

        if constexpr (scale == 8) { return vtype::reverse(reg); }
        else if constexpr (scale == 2) {
            return swap_n<vtype, 2>(reg);
        }
        else if constexpr (scale == 4) {
            lanes = permute_ps_si256<0b00011011>(lanes);
        }
        else {
            static_assert(scale == -1, "should not be reached");
        }

        return vtype::cast_from(lanes);
    }

    /*
     * Blends two registers lane by lane. A set bit i in the blend immediate
     * takes lane i from `other`; a clear bit keeps lane i of `reg`:
     *   scale == 2: lanes 0, 2, 4, 6 come from `other`,
     *   scale == 4: lanes 0, 1, 4, 5 come from `other`,
     *   scale == 8: lanes 0-3 come from `other`.
     */
    template <typename vtype, int scale>
    X86_SIMD_SORT_INLINE typename vtype::reg_t
    merge_n(typename vtype::reg_t reg, typename vtype::reg_t other)
    {
        __m256i kept = vtype::cast_to(reg);
        __m256i incoming = vtype::cast_to(other);

        if constexpr (scale == 8) {
            kept = _mm256_blend_epi32(kept, incoming, 0b00001111);
        }
        else if constexpr (scale == 4) {
            kept = _mm256_blend_epi32(kept, incoming, 0b00110011);
        }
        else if constexpr (scale == 2) {
            kept = _mm256_blend_epi32(kept, incoming, 0b01010101);
        }
        else {
            static_assert(scale == -1, "should not be reached");
        }

        return vtype::cast_from(kept);
    }
};

#endif
16 changes: 7 additions & 9 deletions src/x86simdsort-static-incl.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,26 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false);
std::iota(indices.begin(), indices.end(), 0); \
x86simdsortStatic::argselect(arr, indices.data(), k, size, hasnan); \
return indices; \
} \
template <typename T1, typename T2> \
X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_qsort( \
T1 *key, T2 *val, size_t size, bool hasnan) \
{ \
ISA##_qsort_kv(key, val, size, hasnan); \
}

/*
* qsort, qselect, partial, argsort key-value sort template functions.
*/
#include "xss-common-qsort.h"
#include "xss-common-argsort.h"
#include "xss-common-keyvaluesort.hpp"

#if defined(__AVX512DQ__) && defined(__AVX512VL__)
/* 32-bit and 64-bit dtypes vector definitions on SKX */
#include "avx512-32bit-qsort.hpp"
#include "avx512-64bit-qsort.hpp"
#include "avx512-64bit-argsort.hpp"
#include "avx512-64bit-keyvaluesort.hpp"

/* 16-bit dtypes vector definitions on ICL */
#if defined(__AVX512BW__) && defined(__AVX512VBMI2__)
Expand All @@ -126,14 +132,6 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false);

XSS_METHODS(avx512)

// key-value currently only on avx512
// Definition of the x86simdsortStatic::keyvalue_qsort template: key, val,
// size and hasnan are handed unchanged to avx512_qsort_kv.
// NOTE(review): presumably key and val are parallel arrays of `size`
// elements and hasnan tells the sorter that key may contain NaNs -- confirm
// against avx512_qsort_kv.
template <typename T1, typename T2>
X86_SIMD_SORT_FINLINE void
x86simdsortStatic::keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan)
{
avx512_qsort_kv(key, val, size, hasnan);
}

#elif defined(__AVX512F__)
#error "x86simdsort requires AVX512DQ and AVX512VL to be enabled in addition to AVX512F to use AVX512"

Expand Down
Loading