Skip to content

Commit 1b39637

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #61 from sterrettm2/generic_nets
Changes quicksort and quickselect to use template based sorting networks
2 parents b9f9340 + cea0a72 commit 1b39637

12 files changed

+1000
-1902
lines changed

src/avx512-16bit-common.h

Lines changed: 5 additions & 208 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define AVX512_16BIT_COMMON
99

1010
#include "avx512-common-qsort.h"
11+
#include "xss-network-qsort.hpp"
1112

1213
/*
1314
* Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic
@@ -33,8 +34,8 @@ static const uint16_t network[6][32]
3334
* Assumes zmm is random and performs a full sorting network defined in
3435
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
3536
*/
36-
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
37-
X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit(zmm_t zmm)
37+
template <typename vtype, typename reg_t = typename vtype::reg_t>
38+
X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit(reg_t zmm)
3839
{
3940
// Level 1
4041
zmm = cmp_merge<vtype>(
@@ -93,8 +94,8 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit(zmm_t zmm)
9394
}
9495

9596
// Assumes zmm is bitonic and performs a recursive half cleaner
96-
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
97-
X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
97+
template <typename vtype, typename reg_t = typename vtype::reg_t>
98+
X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_16bit(reg_t zmm)
9899
{
99100
// 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
100101
zmm = cmp_merge<vtype>(
@@ -118,208 +119,4 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
118119
return zmm;
119120
}
120121

121-
// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
122-
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
123-
X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
124-
{
125-
// 1) First step of a merging network: coex of zmm1 and zmm2 reversed
126-
zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2);
127-
zmm_t zmm3 = vtype::min(zmm1, zmm2);
128-
zmm_t zmm4 = vtype::max(zmm1, zmm2);
129-
// 2) Recursive half cleaner for each
130-
zmm1 = bitonic_merge_zmm_16bit<vtype>(zmm3);
131-
zmm2 = bitonic_merge_zmm_16bit<vtype>(zmm4);
132-
}
133-
134-
// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
135-
// half cleaner
136-
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
137-
X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
138-
{
139-
zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]);
140-
zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]);
141-
zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r);
142-
zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r);
143-
zmm_t zmm_t3 = vtype::permutexvar(vtype::get_network(4),
144-
vtype::max(zmm[1], zmm2r));
145-
zmm_t zmm_t4 = vtype::permutexvar(vtype::get_network(4),
146-
vtype::max(zmm[0], zmm3r));
147-
zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2);
148-
zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2);
149-
zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4);
150-
zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4);
151-
zmm[0] = bitonic_merge_zmm_16bit<vtype>(zmm0);
152-
zmm[1] = bitonic_merge_zmm_16bit<vtype>(zmm1);
153-
zmm[2] = bitonic_merge_zmm_16bit<vtype>(zmm2);
154-
zmm[3] = bitonic_merge_zmm_16bit<vtype>(zmm3);
155-
}
156-
157-
template <typename vtype, typename type_t>
158-
X86_SIMD_SORT_INLINE void sort_32_16bit(type_t *arr, int32_t N)
159-
{
160-
typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF;
161-
typename vtype::zmm_t zmm
162-
= vtype::mask_loadu(vtype::zmm_max(), load_mask, arr);
163-
vtype::mask_storeu(arr, load_mask, sort_zmm_16bit<vtype>(zmm));
164-
}
165-
166-
template <typename vtype, typename type_t>
167-
X86_SIMD_SORT_INLINE void sort_64_16bit(type_t *arr, int32_t N)
168-
{
169-
if (N <= 32) {
170-
sort_32_16bit<vtype>(arr, N);
171-
return;
172-
}
173-
using zmm_t = typename vtype::zmm_t;
174-
typename vtype::opmask_t load_mask
175-
= ((0x1ull << (N - 32)) - 0x1ull) & 0xFFFFFFFF;
176-
zmm_t zmm1 = vtype::loadu(arr);
177-
zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 32);
178-
zmm1 = sort_zmm_16bit<vtype>(zmm1);
179-
zmm2 = sort_zmm_16bit<vtype>(zmm2);
180-
bitonic_merge_two_zmm_16bit<vtype>(zmm1, zmm2);
181-
vtype::storeu(arr, zmm1);
182-
vtype::mask_storeu(arr + 32, load_mask, zmm2);
183-
}
184-
185-
template <typename vtype, typename type_t>
186-
X86_SIMD_SORT_INLINE void sort_128_16bit(type_t *arr, int32_t N)
187-
{
188-
if (N <= 64) {
189-
sort_64_16bit<vtype>(arr, N);
190-
return;
191-
}
192-
using zmm_t = typename vtype::zmm_t;
193-
using opmask_t = typename vtype::opmask_t;
194-
zmm_t zmm[4];
195-
zmm[0] = vtype::loadu(arr);
196-
zmm[1] = vtype::loadu(arr + 32);
197-
opmask_t load_mask1 = 0xFFFFFFFF, load_mask2 = 0xFFFFFFFF;
198-
if (N != 128) {
199-
uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull;
200-
load_mask1 = combined_mask & 0xFFFFFFFF;
201-
load_mask2 = (combined_mask >> 32) & 0xFFFFFFFF;
202-
}
203-
zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64);
204-
zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 96);
205-
zmm[0] = sort_zmm_16bit<vtype>(zmm[0]);
206-
zmm[1] = sort_zmm_16bit<vtype>(zmm[1]);
207-
zmm[2] = sort_zmm_16bit<vtype>(zmm[2]);
208-
zmm[3] = sort_zmm_16bit<vtype>(zmm[3]);
209-
bitonic_merge_two_zmm_16bit<vtype>(zmm[0], zmm[1]);
210-
bitonic_merge_two_zmm_16bit<vtype>(zmm[2], zmm[3]);
211-
bitonic_merge_four_zmm_16bit<vtype>(zmm);
212-
vtype::storeu(arr, zmm[0]);
213-
vtype::storeu(arr + 32, zmm[1]);
214-
vtype::mask_storeu(arr + 64, load_mask1, zmm[2]);
215-
vtype::mask_storeu(arr + 96, load_mask2, zmm[3]);
216-
}
217-
218-
template <typename vtype, typename type_t>
219-
X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
220-
const int64_t left,
221-
const int64_t right)
222-
{
223-
// median of 32
224-
int64_t size = (right - left) / 32;
225-
type_t vec_arr[32] = {arr[left],
226-
arr[left + size],
227-
arr[left + 2 * size],
228-
arr[left + 3 * size],
229-
arr[left + 4 * size],
230-
arr[left + 5 * size],
231-
arr[left + 6 * size],
232-
arr[left + 7 * size],
233-
arr[left + 8 * size],
234-
arr[left + 9 * size],
235-
arr[left + 10 * size],
236-
arr[left + 11 * size],
237-
arr[left + 12 * size],
238-
arr[left + 13 * size],
239-
arr[left + 14 * size],
240-
arr[left + 15 * size],
241-
arr[left + 16 * size],
242-
arr[left + 17 * size],
243-
arr[left + 18 * size],
244-
arr[left + 19 * size],
245-
arr[left + 20 * size],
246-
arr[left + 21 * size],
247-
arr[left + 22 * size],
248-
arr[left + 23 * size],
249-
arr[left + 24 * size],
250-
arr[left + 25 * size],
251-
arr[left + 26 * size],
252-
arr[left + 27 * size],
253-
arr[left + 28 * size],
254-
arr[left + 29 * size],
255-
arr[left + 30 * size],
256-
arr[left + 31 * size]};
257-
typename vtype::zmm_t rand_vec = vtype::loadu(vec_arr);
258-
typename vtype::zmm_t sort = sort_zmm_16bit<vtype>(rand_vec);
259-
return ((type_t *)&sort)[16];
260-
}
261-
262-
template <typename vtype, typename type_t>
263-
static void
264-
qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
265-
{
266-
/*
267-
* Resort to std::sort if quicksort isnt making any progress
268-
*/
269-
if (max_iters <= 0) {
270-
std::sort(arr + left, arr + right + 1, comparison_func<vtype>);
271-
return;
272-
}
273-
/*
274-
* Base case: use bitonic networks to sort arrays <= 128
275-
*/
276-
if (right + 1 - left <= 128) {
277-
sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
278-
return;
279-
}
280-
281-
type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
282-
type_t smallest = vtype::type_max();
283-
type_t biggest = vtype::type_min();
284-
int64_t pivot_index = partition_avx512<vtype>(
285-
arr, left, right + 1, pivot, &smallest, &biggest);
286-
if (pivot != smallest)
287-
qsort_16bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
288-
if (pivot != biggest)
289-
qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1);
290-
}
291-
292-
template <typename vtype, typename type_t>
293-
static void qselect_16bit_(type_t *arr,
294-
int64_t pos,
295-
int64_t left,
296-
int64_t right,
297-
int64_t max_iters)
298-
{
299-
/*
300-
* Resort to std::sort if quicksort isnt making any progress
301-
*/
302-
if (max_iters <= 0) {
303-
std::sort(arr + left, arr + right + 1, comparison_func<vtype>);
304-
return;
305-
}
306-
/*
307-
* Base case: use bitonic networks to sort arrays <= 128
308-
*/
309-
if (right + 1 - left <= 128) {
310-
sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
311-
return;
312-
}
313-
314-
type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
315-
type_t smallest = vtype::type_max();
316-
type_t biggest = vtype::type_min();
317-
int64_t pivot_index = partition_avx512<vtype>(
318-
arr, left, right + 1, pivot, &smallest, &biggest);
319-
if ((pivot != smallest) && (pos < pivot_index))
320-
qselect_16bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
321-
else if ((pivot != biggest) && (pos >= pivot_index))
322-
qselect_16bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
323-
}
324-
325122
#endif // AVX512_16BIT_COMMON

0 commit comments

Comments
 (0)