
Commit dc57afb

Author: Raghuveer Devulapalli

Merge pull request #16 from r-devulap/avx512fp16

Add qsort for _Float16 using AVX-512 FP16 ISA

2 parents 58501d0 + 3c21c7f, commit dc57afb

14 files changed: +603 additions, -353 deletions

Makefile

Lines changed: 11 additions & 8 deletions

@@ -1,16 +1,15 @@
-CXX ?= g++
+CXX = g++-12
 SRCDIR = ./src
 TESTDIR = ./tests
 BENCHDIR = ./benchmarks
 UTILS = ./utils
 SRCS = $(wildcard $(SRCDIR)/*.hpp)
 TESTS = $(wildcard $(TESTDIR)/*.cpp)
 TESTOBJS = $(patsubst $(TESTDIR)/%.cpp,$(TESTDIR)/%.o,$(TESTS))
-TESTOBJS := $(filter-out $(TESTDIR)/main.o ,$(TESTOBJS))
 CXXFLAGS += -I$(SRCDIR) -I$(UTILS)
-GTESTCFLAGS = `pkg-config --cflags gtest`
-GTESTLDFLAGS = `pkg-config --libs gtest`
-MARCHFLAG = -march=icelake-client -O3
+GTESTCFLAGS = `pkg-config --cflags gtest_main`
+GTESTLDFLAGS = `pkg-config --libs gtest_main`
+MARCHFLAG = -march=sapphirerapids -O3
 
 all : test bench
 
@@ -20,11 +19,15 @@ $(UTILS)/cpuinfo.o : $(UTILS)/cpuinfo.cpp
 $(TESTDIR)/%.o : $(TESTDIR)/%.cpp $(SRCS)
 	$(CXX) $(CXXFLAGS) $(MARCHFLAG) $(GTESTCFLAGS) -c $< -o $@
 
-test: $(TESTDIR)/main.cpp $(TESTOBJS) $(UTILS)/cpuinfo.o $(SRCS)
-	$(CXX) tests/main.cpp $(TESTOBJS) $(UTILS)/cpuinfo.o $(MARCHFLAG) $(CXXFLAGS) $(GTESTLDFLAGS) -o testexe
+test: $(TESTOBJS) $(UTILS)/cpuinfo.o $(SRCS)
+	$(CXX) $(TESTOBJS) $(UTILS)/cpuinfo.o $(MARCHFLAG) $(CXXFLAGS) -lgtest_main $(GTESTLDFLAGS) -o testexe
 
 bench: $(BENCHDIR)/main.cpp $(SRCS) $(UTILS)/cpuinfo.o
 	$(CXX) $(BENCHDIR)/main.cpp $(CXXFLAGS) $(UTILS)/cpuinfo.o $(MARCHFLAG) -o benchexe
 
+meson:
+	meson setup --warnlevel 0 --buildtype plain builddir
+	cd builddir && ninja
+
 clean:
-	rm -f $(TESTDIR)/*.o testexe benchexe
+	$(RM) -rf $(TESTDIR)/*.o $(UTILS)/*.o testexe benchexe builddir

meson.build

Lines changed: 16 additions & 24 deletions

@@ -1,32 +1,24 @@
-project('x86-simd-sort', 'c', 'cpp',
+project('x86-simd-sort', 'cpp',
         version : '1.0.0',
         license : 'BSD 3-clause')
-cc = meson.get_compiler('c')
 cpp = meson.get_compiler('cpp')
-src = include_directories('./src')
-bench = include_directories('./benchmarks')
-utils = include_directories('./utils')
-tests = include_directories('./tests')
-gtest_dep = dependency('gtest', fallback : ['gtest', 'gtest_dep'])
-subdir('./tests')
+src = include_directories('src')
+bench = include_directories('benchmarks')
+utils = include_directories('utils')
+tests = include_directories('tests')
+gtest_dep = dependency('gtest_main', required : true)
+subdir('utils')
+subdir('tests')
 
-testexe = executable('testexe', 'tests/main.cpp',
+testexe = executable('testexe',
+                     include_directories : [src, utils],
                      dependencies : gtest_dep,
-                     link_whole : [
-                             libtests,
-                     ]
-                     )
+                     link_whole : [libtests, libcpuinfo]
+                     )
 
 benchexe = executable('benchexe', 'benchmarks/main.cpp',
-                      include_directories : [
-                              src,
-                              utils,
-                              bench,
-                      ],
-                      cpp_args : [
-                              '-O3',
-                              '-march=icelake-client',
-                      ],
-                      dependencies : [],
-                      link_whole : [],
+                      include_directories : [src, utils, bench],
+                      cpp_args : [ '-O3', '-march=icelake-client' ],
+                      dependencies : [],
+                      link_whole : [libcpuinfo],
                       )

src/avx512-16bit-common.h

Lines changed: 292 additions & 0 deletions

@@ -0,0 +1,292 @@
+/*******************************************************************
+ * Copyright (C) 2022 Intel Corporation
+ * SPDX-License-Identifier: BSD-3-Clause
+ * Authors: Raghuveer Devulapalli <[email protected]>
+ * ****************************************************************/
+
+#ifndef AVX512_16BIT_COMMON
+#define AVX512_16BIT_COMMON
+
+#include "avx512-common-qsort.h"
+
+/*
+ * Constants used in sorting 32 elements in a ZMM register. Based on Bitonic
+ * sorting network (see
+ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
+ */
+// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0
+static const uint16_t network[6][32]
+        = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8,
+            23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24},
+           {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
+            31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16},
+           {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11,
+            20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27},
+           {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
+            15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
+           {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7,
+            24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23},
+           {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}};
+
+/*
+ * Assumes zmm is random and performs a full sorting network defined in
+ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
+ */
+template <typename vtype, typename zmm_t = typename vtype::zmm_t>
+X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit(zmm_t zmm)
+{
+    // Level 1
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    // Level 2
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(0, 1, 2, 3)>(zmm),
+            0xCCCCCCCC);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    // Level 3
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(1), zmm), 0xF0F0F0F0);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
+            0xCCCCCCCC);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    // Level 4
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(2), zmm), 0xFF00FF00);
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
+            0xCCCCCCCC);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    // Level 5
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(4), zmm), 0xFFFF0000);
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
+            0xCCCCCCCC);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    return zmm;
+}
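For reference, cmp_merge<vtype>(a, b, mask) acts as a masked compare-exchange: a lane whose mask bit is 0 receives min(a, b), a lane whose bit is 1 receives max(a, b), which is why the 0xAAAAAAAA / 0xCCCCCCCC / ... masks pick which half of every comparator pair keeps the larger value. A minimal scalar sketch of the first network level under that reading (cmp_merge_scalar is a hypothetical stand-in, not the library's vector routine):

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>

// Hypothetical scalar model of cmp_merge: lane i keeps min(a[i], b[i]) when
// mask bit i is 0 and max(a[i], b[i]) when it is 1.
static std::array<uint16_t, 32> cmp_merge_scalar(const std::array<uint16_t, 32> &a,
                                                 const std::array<uint16_t, 32> &b,
                                                 uint32_t mask)
{
    std::array<uint16_t, 32> out {};
    for (int i = 0; i < 32; ++i)
        out[i] = ((mask >> i) & 1) ? std::max(a[i], b[i]) : std::min(a[i], b[i]);
    return out;
}

int main()
{
    // Level 1: the SHUFFLE_MASK(2, 3, 0, 1) shuffle swaps neighbouring lanes,
    // and mask 0xAAAAAAAA routes each pair's maximum into the odd lane, so
    // every (even, odd) pair of the register comes out in ascending order.
    std::array<uint16_t, 32> v {}, swapped {};
    for (int i = 0; i < 32; ++i)
        v[i] = static_cast<uint16_t>((97 * i) % 32); // arbitrary unsorted data
    for (int i = 0; i < 32; i += 2) {
        swapped[i] = v[i + 1];
        swapped[i + 1] = v[i];
    }
    std::array<uint16_t, 32> r = cmp_merge_scalar(v, swapped, 0xAAAAAAAA);
    for (int i = 0; i < 32; i += 2)
        std::printf("(%u, %u) ", (unsigned)r[i], (unsigned)r[i + 1]); // ascending pairs
    std::printf("\n");
}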
+
+// Assumes zmm is bitonic and performs a recursive half cleaner
+template <typename vtype, typename zmm_t = typename vtype::zmm_t>
+X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
+{
+    // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(6), zmm), 0xFFFF0000);
+    // 2) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc ..
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
+    // 3) half_cleaner[8]
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
+    // 4) half_cleaner[4]
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
+            0xCCCCCCCC);
+    // 5) half_cleaner[2]
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    return zmm;
+}
+
+// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
+template <typename vtype, typename zmm_t = typename vtype::zmm_t>
+X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
+{
+    // 1) First step of a merging network: coex of zmm1 and zmm2 reversed
+    zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2);
+    zmm_t zmm3 = vtype::min(zmm1, zmm2);
+    zmm_t zmm4 = vtype::max(zmm1, zmm2);
+    // 2) Recursive half cleaner for each
+    zmm1 = bitonic_merge_zmm_16bit<vtype>(zmm3);
+    zmm2 = bitonic_merge_zmm_16bit<vtype>(zmm4);
+}
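bitonic_merge_two_zmm_16bit above follows the textbook bitonic merge: reverse the second sorted vector so the pair together forms a bitonic sequence, take lane-wise min/max, then run the recursive half cleaner on each result. A compact scalar sketch of the same idea (half_clean and merge_two_sorted are illustrative names, not the library API):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Recursive half cleaner: on a bitonic sequence, comparing element i with
// element i + n/2 and recursing on both halves leaves the range sorted.
static void half_clean(std::vector<uint16_t> &v, size_t lo, size_t n)
{
    if (n <= 1) return;
    size_t half = n / 2;
    for (size_t i = lo; i < lo + half; ++i)
        if (v[i] > v[i + half]) std::swap(v[i], v[i + half]);
    half_clean(v, lo, half);
    half_clean(v, lo + half, half);
}

// Merge two sorted 32-element blocks the way bitonic_merge_two_zmm_16bit does:
// reverse the second block, take pairwise min/max, then half-clean each block.
static void merge_two_sorted(std::vector<uint16_t> &a, std::vector<uint16_t> &b)
{
    std::reverse(b.begin(), b.end());
    for (size_t i = 0; i < a.size(); ++i) {
        uint16_t lo = std::min(a[i], b[i]);
        uint16_t hi = std::max(a[i], b[i]);
        a[i] = lo;
        b[i] = hi;
    }
    half_clean(a, 0, a.size());
    half_clean(b, 0, b.size());
}

int main()
{
    std::vector<uint16_t> a(32), b(32);
    for (int i = 0; i < 32; ++i) { a[i] = 2 * i; b[i] = 2 * i + 1; }
    merge_two_sorted(a, b);
    assert(std::is_sorted(a.begin(), a.end()));
    assert(std::is_sorted(b.begin(), b.end()));
    assert(a.back() <= b.front()); // a holds the 32 smallest, b the 32 largest
}

In the vector version the reversal is the permutexvar with the lane-reversal pattern from the network table, and the scalar recursion is unrolled into the fixed stages of bitonic_merge_zmm_16bit.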
+
+// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
+// half cleaner
+template <typename vtype, typename zmm_t = typename vtype::zmm_t>
+X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
+{
+    zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]);
+    zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]);
+    zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r);
+    zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r);
+    zmm_t zmm_t3 = vtype::permutexvar(vtype::get_network(4),
+                                      vtype::max(zmm[1], zmm2r));
+    zmm_t zmm_t4 = vtype::permutexvar(vtype::get_network(4),
+                                      vtype::max(zmm[0], zmm3r));
+    zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2);
+    zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2);
+    zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4);
+    zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4);
+    zmm[0] = bitonic_merge_zmm_16bit<vtype>(zmm0);
+    zmm[1] = bitonic_merge_zmm_16bit<vtype>(zmm1);
+    zmm[2] = bitonic_merge_zmm_16bit<vtype>(zmm2);
+    zmm[3] = bitonic_merge_zmm_16bit<vtype>(zmm3);
+}
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE void sort_32_16bit(type_t *arr, int32_t N)
+{
+    typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF;
+    typename vtype::zmm_t zmm
+            = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr);
+    vtype::mask_storeu(arr, load_mask, sort_zmm_16bit<vtype>(zmm));
+}
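sort_32_16bit handles a partial vector by building a 32-bit load mask with one bit per valid element; the remaining lanes are filled from vtype::zmm_max(), so the padding sorts to the top and is never written back because the masked store reuses the same mask. A standalone worked example of that mask arithmetic (not library code):

#include <cstdint>
#include <cstdio>

// Same expression as in sort_32_16bit: bit i is set iff element i is valid.
static uint32_t load_mask_32(int32_t N)
{
    return static_cast<uint32_t>(((0x1ull << N) - 0x1ull) & 0xFFFFFFFF);
}

int main()
{
    // N = 5  -> 0x0000001F, N = 20 -> 0x000FFFFF, N = 32 -> 0xFFFFFFFF.
    // The 64-bit 0x1ull matters: for N = 32 the shift happens on a 64-bit
    // value, so the expression stays well defined before being truncated.
    std::printf("%08X %08X %08X\n",
                load_mask_32(5), load_mask_32(20), load_mask_32(32));
}

sort_64_16bit and sort_128_16bit below apply the same idea to the tail vector only, using (N - 32) and a 64-bit combined mask split across two registers.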
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE void sort_64_16bit(type_t *arr, int32_t N)
+{
+    if (N <= 32) {
+        sort_32_16bit<vtype>(arr, N);
+        return;
+    }
+    using zmm_t = typename vtype::zmm_t;
+    typename vtype::opmask_t load_mask
+            = ((0x1ull << (N - 32)) - 0x1ull) & 0xFFFFFFFF;
+    zmm_t zmm1 = vtype::loadu(arr);
+    zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 32);
+    zmm1 = sort_zmm_16bit<vtype>(zmm1);
+    zmm2 = sort_zmm_16bit<vtype>(zmm2);
+    bitonic_merge_two_zmm_16bit<vtype>(zmm1, zmm2);
+    vtype::storeu(arr, zmm1);
+    vtype::mask_storeu(arr + 32, load_mask, zmm2);
+}
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE void sort_128_16bit(type_t *arr, int32_t N)
+{
+    if (N <= 64) {
+        sort_64_16bit<vtype>(arr, N);
+        return;
+    }
+    using zmm_t = typename vtype::zmm_t;
+    using opmask_t = typename vtype::opmask_t;
+    zmm_t zmm[4];
+    zmm[0] = vtype::loadu(arr);
+    zmm[1] = vtype::loadu(arr + 32);
+    opmask_t load_mask1 = 0xFFFFFFFF, load_mask2 = 0xFFFFFFFF;
+    if (N != 128) {
+        uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull;
+        load_mask1 = combined_mask & 0xFFFFFFFF;
+        load_mask2 = (combined_mask >> 32) & 0xFFFFFFFF;
+    }
+    zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64);
+    zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 96);
+    zmm[0] = sort_zmm_16bit<vtype>(zmm[0]);
+    zmm[1] = sort_zmm_16bit<vtype>(zmm[1]);
+    zmm[2] = sort_zmm_16bit<vtype>(zmm[2]);
+    zmm[3] = sort_zmm_16bit<vtype>(zmm[3]);
+    bitonic_merge_two_zmm_16bit<vtype>(zmm[0], zmm[1]);
+    bitonic_merge_two_zmm_16bit<vtype>(zmm[2], zmm[3]);
+    bitonic_merge_four_zmm_16bit<vtype>(zmm);
+    vtype::storeu(arr, zmm[0]);
+    vtype::storeu(arr + 32, zmm[1]);
+    vtype::mask_storeu(arr + 64, load_mask1, zmm[2]);
+    vtype::mask_storeu(arr + 96, load_mask2, zmm[3]);
+}
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
+                                            const int64_t left,
+                                            const int64_t right)
+{
+    // median of 32
+    int64_t size = (right - left) / 32;
+    type_t vec_arr[32] = {arr[left],
+                          arr[left + size],
+                          arr[left + 2 * size],
+                          arr[left + 3 * size],
+                          arr[left + 4 * size],
+                          arr[left + 5 * size],
+                          arr[left + 6 * size],
+                          arr[left + 7 * size],
+                          arr[left + 8 * size],
+                          arr[left + 9 * size],
+                          arr[left + 10 * size],
+                          arr[left + 11 * size],
+                          arr[left + 12 * size],
+                          arr[left + 13 * size],
+                          arr[left + 14 * size],
+                          arr[left + 15 * size],
+                          arr[left + 16 * size],
+                          arr[left + 17 * size],
+                          arr[left + 18 * size],
+                          arr[left + 19 * size],
+                          arr[left + 20 * size],
+                          arr[left + 21 * size],
+                          arr[left + 22 * size],
+                          arr[left + 23 * size],
+                          arr[left + 24 * size],
+                          arr[left + 25 * size],
+                          arr[left + 26 * size],
+                          arr[left + 27 * size],
+                          arr[left + 28 * size],
+                          arr[left + 29 * size],
+                          arr[left + 30 * size],
+                          arr[left + 31 * size]};
+    typename vtype::zmm_t rand_vec = vtype::loadu(vec_arr);
+    typename vtype::zmm_t sort = sort_zmm_16bit<vtype>(rand_vec);
+    return ((type_t *)&sort)[16];
+}
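get_pivot_16bit samples 32 evenly spaced elements from [left, right], sorts the sample with the same 32-lane network, and takes the middle lane as the pivot. A scalar equivalent for illustration, with std::nth_element standing in for the in-register sort (get_pivot_scalar is a hypothetical name):

#include <algorithm>
#include <array>
#include <cstdint>

// Scalar stand-in for get_pivot_16bit: median of 32 evenly spaced samples.
template <typename T>
static T get_pivot_scalar(const T *arr, int64_t left, int64_t right)
{
    std::array<T, 32> samples;
    int64_t stride = (right - left) / 32;
    for (int i = 0; i < 32; ++i)
        samples[i] = arr[left + i * stride];
    // The AVX-512 version sorts the sample in one ZMM register and reads
    // lane 16; nth_element yields the same median element here.
    std::nth_element(samples.begin(), samples.begin() + 16, samples.end());
    return samples[16];
}

int main()
{
    uint16_t data[1000];
    for (int i = 0; i < 1000; ++i)
        data[i] = static_cast<uint16_t>((i * 7919) % 1000);
    uint16_t pivot = get_pivot_scalar(data, 0, 999);
    (void)pivot; // qsort_16bit_ passes this value to partition_avx512
}

Sampling across the whole range keeps the pivot representative even when the input is partially sorted, which a single median-of-three sample would not.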
+
+template <typename vtype, typename type_t>
+static void
+qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
+{
+    /*
+     * Resort to std::sort if quicksort isn't making any progress
+     */
+    if (max_iters <= 0) {
+        std::sort(arr + left, arr + right + 1, comparison_func<vtype>);
+        return;
+    }
+    /*
+     * Base case: use bitonic networks to sort arrays <= 128
+     */
+    if (right + 1 - left <= 128) {
+        sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
+        return;
+    }
+
+    type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
+    type_t smallest = vtype::type_max();
+    type_t biggest = vtype::type_min();
+    int64_t pivot_index = partition_avx512<vtype>(
+            arr, left, right + 1, pivot, &smallest, &biggest);
+    if (pivot != smallest)
+        qsort_16bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
+    if (pivot != biggest)
+        qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1);
+}
+
+#endif // AVX512_16BIT_COMMON
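qsort_16bit_ bounds its recursion with max_iters and falls back to std::sort once the budget is exhausted, the same safeguard introsort uses against adversarial pivots. The entry point that picks the budget is defined in the other files of this commit and is not shown above, so the 2 * log2(n) below is only an illustrative choice, not the commit's actual constant:

#include <cstdint>
#include <cstdio>

// Illustrative depth budget in the spirit of introsort: proportional to
// log2(n). The real caller of qsort_16bit_ may use a different bound.
static int64_t depth_limit(uint64_t n)
{
    int64_t log2n = 63 - __builtin_clzll(n | 1); // floor(log2(n)) for n >= 1
    return 2 * log2n;
}

int main()
{
    std::printf("n=128   -> max_iters=%lld\n", (long long)depth_limit(128));
    std::printf("n=1<<20 -> max_iters=%lld\n", (long long)depth_limit(1ull << 20));
}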
