 #define AVX512_16BIT_COMMON

 #include "avx512-common-qsort.h"
-#include "xss-network-qsort.hpp"

 /*
  * Constants used in sorting 32 elements in a ZMM register. Based on Bitonic
@@ -93,30 +92,221 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit(reg_t zmm)
     return zmm;
 }

-// Assumes zmm is bitonic and performs a recursive half cleaner
-template <typename vtype, typename reg_t = typename vtype::reg_t>
-X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_16bit(reg_t zmm)
-{
-    // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(6), zmm), 0xFFFF0000);
-    // 2) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc ..
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
-    // 3) half_cleaner[8]
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
-    // 3) half_cleaner[4]
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCCCCCC);
-    // 3) half_cleaner[2]
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    return zmm;
-}
+struct avx512_16bit_swizzle_ops {
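+    // Exchange the two halves of every block of 'scale' contiguous 16-bit elements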
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg)
+    {
+        __m512i v = vtype::cast_to(reg);
+
+        if constexpr (scale == 2) {
+            __m512i mask = _mm512_set_epi16(30, 31, 28, 29, 26, 27, 24, 25,
+                                            22, 23, 20, 21, 18, 19, 16, 17,
+                                            14, 15, 12, 13, 10, 11, 8, 9,
+                                            6, 7, 4, 5, 2, 3, 0, 1);
+            v = _mm512_permutexvar_epi16(mask, v);
+        }
+        else if constexpr (scale == 4) {
+            v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b10110001);
+        }
+        else if constexpr (scale == 8) {
+            v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b01001110);
+        }
+        else if constexpr (scale == 16) {
+            v = _mm512_shuffle_i64x2(v, v, 0b10110001);
+        }
+        else if constexpr (scale == 32) {
+            v = _mm512_shuffle_i64x2(v, v, 0b01001110);
+        }
+        else {
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v);
+    }
+
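+    // Reverse the element order within every block of 'scale' contiguous 16-bit elements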
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t
+    reverse_n(typename vtype::reg_t reg)
+    {
+        __m512i v = vtype::cast_to(reg);
+
+        if constexpr (scale == 2) { return swap_n<vtype, 2>(reg); }
+        else if constexpr (scale == 4) {
+            __m512i mask = _mm512_set_epi16(28, 29, 30, 31, 24, 25, 26, 27,
+                                            20, 21, 22, 23, 16, 17, 18, 19,
+                                            12, 13, 14, 15, 8, 9, 10, 11,
+                                            4, 5, 6, 7, 0, 1, 2, 3);
+            v = _mm512_permutexvar_epi16(mask, v);
+        }
+        else if constexpr (scale == 8) {
+            __m512i mask = _mm512_set_epi16(24, 25, 26, 27, 28, 29, 30, 31,
+                                            16, 17, 18, 19, 20, 21, 22, 23,
+                                            8, 9, 10, 11, 12, 13, 14, 15,
+                                            0, 1, 2, 3, 4, 5, 6, 7);
+            v = _mm512_permutexvar_epi16(mask, v);
+        }
+        else if constexpr (scale == 16) {
+            __m512i mask = _mm512_set_epi16(16, 17, 18, 19, 20, 21, 22, 23,
+                                            24, 25, 26, 27, 28, 29, 30, 31,
+                                            0, 1, 2, 3, 4, 5, 6, 7,
+                                            8, 9, 10, 11, 12, 13, 14, 15);
+            v = _mm512_permutexvar_epi16(mask, v);
+        }
+        else if constexpr (scale == 32) {
+            return vtype::reverse(reg);
+        }
+        else {
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v);
+    }
+
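+    // Within every block of 'scale' elements, take the lower half from 'other' and the upper half from 'reg'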
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t
+    merge_n(typename vtype::reg_t reg, typename vtype::reg_t other)
+    {
+        __m512i v1 = vtype::cast_to(reg);
+        __m512i v2 = vtype::cast_to(other);
+
+        if constexpr (scale == 2) {
+            v1 = _mm512_mask_blend_epi16(
+                    0b01010101010101010101010101010101, v1, v2);
+        }
+        else if constexpr (scale == 4) {
+            v1 = _mm512_mask_blend_epi16(
+                    0b00110011001100110011001100110011, v1, v2);
+        }
+        else if constexpr (scale == 8) {
+            v1 = _mm512_mask_blend_epi16(
+                    0b00001111000011110000111100001111, v1, v2);
+        }
+        else if constexpr (scale == 16) {
+            v1 = _mm512_mask_blend_epi16(
+                    0b00000000111111110000000011111111, v1, v2);
+        }
+        else if constexpr (scale == 32) {
+            v1 = _mm512_mask_blend_epi16(
+                    0b00000000000000001111111111111111, v1, v2);
+        }
+        else {
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v1);
+    }
+};
 
 #endif // AVX512_16BIT_COMMON
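
For reference, a scalar model of the three swizzle primitives added above makes their semantics easy to check against the permute and blend masks. The sketch below is illustrative only and is not part of the header or the commit; the array type, the scalar_* names, and the main driver are assumptions chosen to mirror the 32 16-bit lanes of a ZMM register.

// Scalar sketch (assumed, for illustration only) of the swizzle semantics
// on a 32-element array standing in for the 16-bit lanes of a ZMM register.
#include <array>
#include <cstdint>
#include <cstdio>

using lanes_t = std::array<uint16_t, 32>;

// swap_n<scale>: exchange the two halves of every block of 'scale' elements.
template <int scale>
lanes_t scalar_swap_n(const lanes_t &in)
{
    lanes_t out {};
    for (int i = 0; i < 32; ++i) {
        int block = i - (i % scale);
        out[i] = in[block + (i % scale + scale / 2) % scale];
    }
    return out;
}

// reverse_n<scale>: reverse the element order within every block of 'scale'.
template <int scale>
lanes_t scalar_reverse_n(const lanes_t &in)
{
    lanes_t out {};
    for (int i = 0; i < 32; ++i) {
        int block = i - (i % scale);
        out[i] = in[block + (scale - 1 - i % scale)];
    }
    return out;
}

// merge_n<scale>: lower half of each block from 'other', upper half from 'reg'.
template <int scale>
lanes_t scalar_merge_n(const lanes_t &reg, const lanes_t &other)
{
    lanes_t out {};
    for (int i = 0; i < 32; ++i) {
        out[i] = (i % scale < scale / 2) ? other[i] : reg[i];
    }
    return out;
}

int main()
{
    lanes_t v {};
    for (int i = 0; i < 32; ++i) v[i] = static_cast<uint16_t>(i);
    lanes_t s = scalar_swap_n<4>(v); // first block becomes {2, 3, 0, 1}
    lanes_t r = scalar_reverse_n<4>(v); // first block becomes {3, 2, 1, 0}
    std::printf("%u %u %u %u | %u %u %u %u\n",
                s[0], s[1], s[2], s[3], r[0], r[1], r[2], r[3]);
}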