@@ -24,6 +24,7 @@ template <typename vtype, typename reg_t>
 X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm);
 
 struct avx512_64bit_swizzle_ops;
+struct avx512_ymm_64bit_swizzle_ops;
 
 template <>
 struct ymm_vector<float> {
@@ -34,6 +35,8 @@ struct ymm_vector<float> {
     static const uint8_t numlanes = 8;
     static constexpr simd_type vec_type = simd_type::AVX512;
 
+    using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
+
     static type_t type_max()
     {
         return X86_SIMD_SORT_INFINITYF;
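Note: the new swizzle_ops alias (repeated for each ymm_vector specialization below) is the hook through which generic sorting-network code reaches the YMM swizzle primitives; the avx512_ymm_64bit_swizzle_ops struct it names is defined in the final hunk, and a usage sketch follows it.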
@@ -212,6 +215,14 @@ struct ymm_vector<float> {
         const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
         return permutexvar(rev_index, ymm);
     }
+    static int double_compressstore(type_t *left_addr,
+                                    type_t *right_addr,
+                                    opmask_t k,
+                                    reg_t reg)
+    {
+        return avx512_double_compressstore<ymm_vector<type_t>>(
+                left_addr, right_addr, k, reg);
+    }
 };
 template <>
 struct ymm_vector<uint32_t> {
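The double_compressstore wrapper added above defers to avx512_double_compressstore, which is defined elsewhere in the library. As a rough sketch of the idea rather than the library's exact implementation: lanes whose mask bit is clear are compress-stored to the left buffer and the rest to the right buffer, which is the core of one vectorized partition step. The standalone helper below, with the made-up name double_compressstore_u32, illustrates this for eight uint32_t lanes; the real helper may differ in details such as the offset at which it writes into right_addr.

#include <immintrin.h>
#include <stdint.h>

// Illustration only: partition eight uint32_t lanes by mask k. Lanes whose
// mask bit is clear are packed contiguously to left_addr, lanes whose bit
// is set to right_addr. Returns the number of lanes sent right.
int double_compressstore_u32(uint32_t *left_addr,
                             uint32_t *right_addr,
                             __mmask8 k,
                             __m256i reg)
{
    int amount_right = _mm_popcnt_u32((unsigned int)k);
    _mm256_mask_compressstoreu_epi32(left_addr, (__mmask8)~k, reg);
    _mm256_mask_compressstoreu_epi32(right_addr, k, reg);
    return amount_right;
}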
@@ -222,6 +233,8 @@ struct ymm_vector<uint32_t> {
     static const uint8_t numlanes = 8;
     static constexpr simd_type vec_type = simd_type::AVX512;
 
+    using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
+
     static type_t type_max()
     {
         return X86_SIMD_SORT_MAX_UINT32;
@@ -386,6 +399,14 @@ struct ymm_vector<uint32_t> {
         const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
         return permutexvar(rev_index, ymm);
     }
+    static int double_compressstore(type_t *left_addr,
+                                    type_t *right_addr,
+                                    opmask_t k,
+                                    reg_t reg)
+    {
+        return avx512_double_compressstore<ymm_vector<type_t>>(
+                left_addr, right_addr, k, reg);
+    }
 };
 template <>
 struct ymm_vector<int32_t> {
@@ -396,6 +417,8 @@ struct ymm_vector<int32_t> {
     static const uint8_t numlanes = 8;
     static constexpr simd_type vec_type = simd_type::AVX512;
 
+    using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
+
     static type_t type_max()
     {
         return X86_SIMD_SORT_MAX_INT32;
@@ -560,6 +583,14 @@ struct ymm_vector<int32_t> {
         const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
         return permutexvar(rev_index, ymm);
     }
+    static int double_compressstore(type_t *left_addr,
+                                    type_t *right_addr,
+                                    opmask_t k,
+                                    reg_t reg)
+    {
+        return avx512_double_compressstore<ymm_vector<type_t>>(
+                left_addr, right_addr, k, reg);
+    }
 };
 template <>
 struct zmm_vector<int64_t> {
@@ -1215,4 +1246,77 @@ struct avx512_64bit_swizzle_ops {
     }
 };
 
+struct avx512_ymm_64bit_swizzle_ops {
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg)
+    {
+        __m256i v = vtype::cast_to(reg);
+
+        if constexpr (scale == 2) {
+            __m256 vf = _mm256_castsi256_ps(v);
+            vf = _mm256_permute_ps(vf, 0b10110001);
+            v = _mm256_castps_si256(vf);
+        }
+        else if constexpr (scale == 4) {
+            __m256 vf = _mm256_castsi256_ps(v);
+            vf = _mm256_permute_ps(vf, 0b01001110);
+            v = _mm256_castps_si256(vf);
+        }
+        else if constexpr (scale == 8) {
+            v = _mm256_permute2x128_si256(v, v, 0b00000001);
+        }
+        else {
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v);
+    }
+
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t
+    reverse_n(typename vtype::reg_t reg)
+    {
+        __m256i v = vtype::cast_to(reg);
+
+        if constexpr (scale == 2) { return swap_n<vtype, 2>(reg); }
+        else if constexpr (scale == 4) {
+            constexpr uint64_t mask = 0b00011011;
+            __m256 vf = _mm256_castsi256_ps(v);
+            vf = _mm256_permute_ps(vf, mask);
+            v = _mm256_castps_si256(vf);
+        }
+        else if constexpr (scale == 8) {
+            return vtype::reverse(reg);
+        }
+        else {
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v);
+    }
+
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t
+    merge_n(typename vtype::reg_t reg, typename vtype::reg_t other)
+    {
+        __m256i v1 = vtype::cast_to(reg);
+        __m256i v2 = vtype::cast_to(other);
+
+        if constexpr (scale == 2) {
+            v1 = _mm256_blend_epi32(v1, v2, 0b01010101);
+        }
+        else if constexpr (scale == 4) {
+            v1 = _mm256_blend_epi32(v1, v2, 0b00110011);
+        }
+        else if constexpr (scale == 8) {
+            v1 = _mm256_blend_epi32(v1, v2, 0b00001111);
+        }
+        else {
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v1);
+    }
+};
+
 #endif
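For orientation, here is a minimal sketch of how these primitives compose into one exchange step of a sorting network. In swap_n, the _mm256_permute_ps masks 0b10110001 and 0b01001110 swap adjacent lanes and adjacent lane pairs within each 128-bit half, and _mm256_permute2x128_si256 swaps the two halves. The sketch assumes this header is included, that vtype is one of the ymm_vector specializations above (which also provide lane-wise min/max), and that X86_SIMD_SORT_INLINE makes these members static, as the call syntax requires; the name exchange_adjacent is made up for illustration and is not part of the library.

// Hypothetical helper, for illustration only. Sorts each adjacent pair of
// lanes ascending: swap neighbours, take lane-wise min/max, then blend the
// minima into even lanes and the maxima into odd lanes.
template <typename vtype>
X86_SIMD_SORT_INLINE typename vtype::reg_t
exchange_adjacent(typename vtype::reg_t v)
{
    using swizzle = typename vtype::swizzle_ops;
    typename vtype::reg_t swapped = swizzle::template swap_n<vtype, 2>(v);
    typename vtype::reg_t vmin = vtype::min(v, swapped);
    typename vtype::reg_t vmax = vtype::max(v, swapped);
    // merge_n<2> takes even lanes from its second argument (blend mask
    // 0b01010101), so vmax goes first and vmin second.
    return swizzle::template merge_n<vtype, 2>(vmax, vmin);
}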