8
8
#define AVX512_16BIT_COMMON
9
9
10
10
#include " avx512-common-qsort.h"
11
+ #include " xss-network-qsort.hpp"
11
12
12
13
/*
13
14
 * Constants used in sorting 32 elements in a ZMM register. Based on Bitonic
@@ -33,8 +34,8 @@ static const uint16_t network[6][32]
33
34
* Assumes zmm is random and performs a full sorting network defined in
34
35
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
35
36
*/
36
- template <typename vtype, typename zmm_t = typename vtype::zmm_t >
37
- X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit (zmm_t zmm)
37
+ template <typename vtype, typename reg_t = typename vtype::reg_t >
38
+ X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit (reg_t zmm)
38
39
{
39
40
// Level 1
40
41
zmm = cmp_merge<vtype>(
@@ -93,8 +94,8 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit(zmm_t zmm)
93
94
}
94
95
95
96
// Assumes zmm is bitonic and performs a recursive half cleaner
96
- template <typename vtype, typename zmm_t = typename vtype::zmm_t >
97
- X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit (zmm_t zmm)
97
+ template <typename vtype, typename reg_t = typename vtype::reg_t >
98
+ X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_16bit (reg_t zmm)
98
99
{
99
100
// 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
100
101
zmm = cmp_merge<vtype>(
@@ -118,208 +119,4 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
118
119
return zmm;
119
120
}
120
121
121
- // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
122
- template <typename vtype, typename zmm_t = typename vtype::zmm_t >
123
- X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_16bit (zmm_t &zmm1, zmm_t &zmm2)
124
- {
125
- // 1) First step of a merging network: coex of zmm1 and zmm2 reversed
126
- zmm2 = vtype::permutexvar (vtype::get_network (4 ), zmm2);
127
- zmm_t zmm3 = vtype::min (zmm1, zmm2);
128
- zmm_t zmm4 = vtype::max (zmm1, zmm2);
129
- // 2) Recursive half cleaner for each
130
- zmm1 = bitonic_merge_zmm_16bit<vtype>(zmm3);
131
- zmm2 = bitonic_merge_zmm_16bit<vtype>(zmm4);
132
- }
133
-
134
- // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
135
- // half cleaner
136
- template <typename vtype, typename zmm_t = typename vtype::zmm_t >
137
- X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_16bit (zmm_t *zmm)
138
- {
139
- zmm_t zmm2r = vtype::permutexvar (vtype::get_network (4 ), zmm[2 ]);
140
- zmm_t zmm3r = vtype::permutexvar (vtype::get_network (4 ), zmm[3 ]);
141
- zmm_t zmm_t1 = vtype::min (zmm[0 ], zmm3r);
142
- zmm_t zmm_t2 = vtype::min (zmm[1 ], zmm2r);
143
- zmm_t zmm_t3 = vtype::permutexvar (vtype::get_network (4 ),
144
- vtype::max (zmm[1 ], zmm2r));
145
- zmm_t zmm_t4 = vtype::permutexvar (vtype::get_network (4 ),
146
- vtype::max (zmm[0 ], zmm3r));
147
- zmm_t zmm0 = vtype::min (zmm_t1, zmm_t2);
148
- zmm_t zmm1 = vtype::max (zmm_t1, zmm_t2);
149
- zmm_t zmm2 = vtype::min (zmm_t3, zmm_t4);
150
- zmm_t zmm3 = vtype::max (zmm_t3, zmm_t4);
151
- zmm[0 ] = bitonic_merge_zmm_16bit<vtype>(zmm0);
152
- zmm[1 ] = bitonic_merge_zmm_16bit<vtype>(zmm1);
153
- zmm[2 ] = bitonic_merge_zmm_16bit<vtype>(zmm2);
154
- zmm[3 ] = bitonic_merge_zmm_16bit<vtype>(zmm3);
155
- }
156
-
157
- template <typename vtype, typename type_t >
158
- X86_SIMD_SORT_INLINE void sort_32_16bit (type_t *arr, int32_t N)
159
- {
160
- typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull ) & 0xFFFFFFFF ;
161
- typename vtype::zmm_t zmm
162
- = vtype::mask_loadu (vtype::zmm_max (), load_mask, arr);
163
- vtype::mask_storeu (arr, load_mask, sort_zmm_16bit<vtype>(zmm));
164
- }
165
-
166
- template <typename vtype, typename type_t >
167
- X86_SIMD_SORT_INLINE void sort_64_16bit (type_t *arr, int32_t N)
168
- {
169
- if (N <= 32 ) {
170
- sort_32_16bit<vtype>(arr, N);
171
- return ;
172
- }
173
- using zmm_t = typename vtype::zmm_t ;
174
- typename vtype::opmask_t load_mask
175
- = ((0x1ull << (N - 32 )) - 0x1ull ) & 0xFFFFFFFF ;
176
- zmm_t zmm1 = vtype::loadu (arr);
177
- zmm_t zmm2 = vtype::mask_loadu (vtype::zmm_max (), load_mask, arr + 32 );
178
- zmm1 = sort_zmm_16bit<vtype>(zmm1);
179
- zmm2 = sort_zmm_16bit<vtype>(zmm2);
180
- bitonic_merge_two_zmm_16bit<vtype>(zmm1, zmm2);
181
- vtype::storeu (arr, zmm1);
182
- vtype::mask_storeu (arr + 32 , load_mask, zmm2);
183
- }
184
-
185
- template <typename vtype, typename type_t >
186
- X86_SIMD_SORT_INLINE void sort_128_16bit (type_t *arr, int32_t N)
187
- {
188
- if (N <= 64 ) {
189
- sort_64_16bit<vtype>(arr, N);
190
- return ;
191
- }
192
- using zmm_t = typename vtype::zmm_t ;
193
- using opmask_t = typename vtype::opmask_t ;
194
- zmm_t zmm[4 ];
195
- zmm[0 ] = vtype::loadu (arr);
196
- zmm[1 ] = vtype::loadu (arr + 32 );
197
- opmask_t load_mask1 = 0xFFFFFFFF , load_mask2 = 0xFFFFFFFF ;
198
- if (N != 128 ) {
199
- uint64_t combined_mask = (0x1ull << (N - 64 )) - 0x1ull ;
200
- load_mask1 = combined_mask & 0xFFFFFFFF ;
201
- load_mask2 = (combined_mask >> 32 ) & 0xFFFFFFFF ;
202
- }
203
- zmm[2 ] = vtype::mask_loadu (vtype::zmm_max (), load_mask1, arr + 64 );
204
- zmm[3 ] = vtype::mask_loadu (vtype::zmm_max (), load_mask2, arr + 96 );
205
- zmm[0 ] = sort_zmm_16bit<vtype>(zmm[0 ]);
206
- zmm[1 ] = sort_zmm_16bit<vtype>(zmm[1 ]);
207
- zmm[2 ] = sort_zmm_16bit<vtype>(zmm[2 ]);
208
- zmm[3 ] = sort_zmm_16bit<vtype>(zmm[3 ]);
209
- bitonic_merge_two_zmm_16bit<vtype>(zmm[0 ], zmm[1 ]);
210
- bitonic_merge_two_zmm_16bit<vtype>(zmm[2 ], zmm[3 ]);
211
- bitonic_merge_four_zmm_16bit<vtype>(zmm);
212
- vtype::storeu (arr, zmm[0 ]);
213
- vtype::storeu (arr + 32 , zmm[1 ]);
214
- vtype::mask_storeu (arr + 64 , load_mask1, zmm[2 ]);
215
- vtype::mask_storeu (arr + 96 , load_mask2, zmm[3 ]);
216
- }
217
-
218
- template <typename vtype, typename type_t >
219
- X86_SIMD_SORT_INLINE type_t get_pivot_16bit (type_t *arr,
220
- const int64_t left,
221
- const int64_t right)
222
- {
223
- // median of 32
224
- int64_t size = (right - left) / 32 ;
225
- type_t vec_arr[32 ] = {arr[left],
226
- arr[left + size],
227
- arr[left + 2 * size],
228
- arr[left + 3 * size],
229
- arr[left + 4 * size],
230
- arr[left + 5 * size],
231
- arr[left + 6 * size],
232
- arr[left + 7 * size],
233
- arr[left + 8 * size],
234
- arr[left + 9 * size],
235
- arr[left + 10 * size],
236
- arr[left + 11 * size],
237
- arr[left + 12 * size],
238
- arr[left + 13 * size],
239
- arr[left + 14 * size],
240
- arr[left + 15 * size],
241
- arr[left + 16 * size],
242
- arr[left + 17 * size],
243
- arr[left + 18 * size],
244
- arr[left + 19 * size],
245
- arr[left + 20 * size],
246
- arr[left + 21 * size],
247
- arr[left + 22 * size],
248
- arr[left + 23 * size],
249
- arr[left + 24 * size],
250
- arr[left + 25 * size],
251
- arr[left + 26 * size],
252
- arr[left + 27 * size],
253
- arr[left + 28 * size],
254
- arr[left + 29 * size],
255
- arr[left + 30 * size],
256
- arr[left + 31 * size]};
257
- typename vtype::zmm_t rand_vec = vtype::loadu (vec_arr);
258
- typename vtype::zmm_t sort = sort_zmm_16bit<vtype>(rand_vec);
259
- return ((type_t *)&sort)[16 ];
260
- }
261
-
262
- template <typename vtype, typename type_t >
263
- static void
264
- qsort_16bit_ (type_t *arr, int64_t left, int64_t right, int64_t max_iters)
265
- {
266
- /*
267
- * Resort to std::sort if quicksort isnt making any progress
268
- */
269
- if (max_iters <= 0 ) {
270
- std::sort (arr + left, arr + right + 1 , comparison_func<vtype>);
271
- return ;
272
- }
273
- /*
274
- * Base case: use bitonic networks to sort arrays <= 128
275
- */
276
- if (right + 1 - left <= 128 ) {
277
- sort_128_16bit<vtype>(arr + left, (int32_t )(right + 1 - left));
278
- return ;
279
- }
280
-
281
- type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
282
- type_t smallest = vtype::type_max ();
283
- type_t biggest = vtype::type_min ();
284
- int64_t pivot_index = partition_avx512<vtype>(
285
- arr, left, right + 1 , pivot, &smallest, &biggest);
286
- if (pivot != smallest)
287
- qsort_16bit_<vtype>(arr, left, pivot_index - 1 , max_iters - 1 );
288
- if (pivot != biggest)
289
- qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1 );
290
- }
291
-
292
- template <typename vtype, typename type_t >
293
- static void qselect_16bit_ (type_t *arr,
294
- int64_t pos,
295
- int64_t left,
296
- int64_t right,
297
- int64_t max_iters)
298
- {
299
- /*
300
- * Resort to std::sort if quicksort isnt making any progress
301
- */
302
- if (max_iters <= 0 ) {
303
- std::sort (arr + left, arr + right + 1 , comparison_func<vtype>);
304
- return ;
305
- }
306
- /*
307
- * Base case: use bitonic networks to sort arrays <= 128
308
- */
309
- if (right + 1 - left <= 128 ) {
310
- sort_128_16bit<vtype>(arr + left, (int32_t )(right + 1 - left));
311
- return ;
312
- }
313
-
314
- type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
315
- type_t smallest = vtype::type_max ();
316
- type_t biggest = vtype::type_min ();
317
- int64_t pivot_index = partition_avx512<vtype>(
318
- arr, left, right + 1 , pivot, &smallest, &biggest);
319
- if ((pivot != smallest) && (pos < pivot_index))
320
- qselect_16bit_<vtype>(arr, pos, left, pivot_index - 1 , max_iters - 1 );
321
- else if ((pivot != biggest) && (pos >= pivot_index))
322
- qselect_16bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1 );
323
- }
324
-
325
122
#endif // AVX512_16BIT_COMMON
0 commit comments