Skip to content

Commit ab29478

Browse files
author
Raghuveer Devulapalli
committed
Add bitonic sorting network of size 256 for 64-bit dtype
1 parent 42ac761 commit ab29478

File tree

1 file changed

+348
-2
lines changed

1 file changed

+348
-2
lines changed

src/avx512-64bit-qsort.hpp

Lines changed: 348 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,161 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm)
179179
zmm[15] = bitonic_merge_zmm_64bit<vtype>(zmm_t16);
180180
}
181181

182+
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
183+
X86_SIMD_SORT_INLINE void bitonic_merge_32_zmm_64bit(zmm_t *zmm)
184+
{
185+
const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2);
186+
zmm_t zmm16r = vtype::permutexvar(rev_index, zmm[16]);
187+
zmm_t zmm17r = vtype::permutexvar(rev_index, zmm[17]);
188+
zmm_t zmm18r = vtype::permutexvar(rev_index, zmm[18]);
189+
zmm_t zmm19r = vtype::permutexvar(rev_index, zmm[19]);
190+
zmm_t zmm20r = vtype::permutexvar(rev_index, zmm[20]);
191+
zmm_t zmm21r = vtype::permutexvar(rev_index, zmm[21]);
192+
zmm_t zmm22r = vtype::permutexvar(rev_index, zmm[22]);
193+
zmm_t zmm23r = vtype::permutexvar(rev_index, zmm[23]);
194+
zmm_t zmm24r = vtype::permutexvar(rev_index, zmm[24]);
195+
zmm_t zmm25r = vtype::permutexvar(rev_index, zmm[25]);
196+
zmm_t zmm26r = vtype::permutexvar(rev_index, zmm[26]);
197+
zmm_t zmm27r = vtype::permutexvar(rev_index, zmm[27]);
198+
zmm_t zmm28r = vtype::permutexvar(rev_index, zmm[28]);
199+
zmm_t zmm29r = vtype::permutexvar(rev_index, zmm[29]);
200+
zmm_t zmm30r = vtype::permutexvar(rev_index, zmm[30]);
201+
zmm_t zmm31r = vtype::permutexvar(rev_index, zmm[31]);
202+
zmm_t zmm_t1 = vtype::min(zmm[0], zmm31r);
203+
zmm_t zmm_t2 = vtype::min(zmm[1], zmm30r);
204+
zmm_t zmm_t3 = vtype::min(zmm[2], zmm29r);
205+
zmm_t zmm_t4 = vtype::min(zmm[3], zmm28r);
206+
zmm_t zmm_t5 = vtype::min(zmm[4], zmm27r);
207+
zmm_t zmm_t6 = vtype::min(zmm[5], zmm26r);
208+
zmm_t zmm_t7 = vtype::min(zmm[6], zmm25r);
209+
zmm_t zmm_t8 = vtype::min(zmm[7], zmm24r);
210+
zmm_t zmm_t9 = vtype::min(zmm[8], zmm23r);
211+
zmm_t zmm_t10 = vtype::min(zmm[9], zmm22r);
212+
zmm_t zmm_t11 = vtype::min(zmm[10], zmm21r);
213+
zmm_t zmm_t12 = vtype::min(zmm[11], zmm20r);
214+
zmm_t zmm_t13 = vtype::min(zmm[12], zmm19r);
215+
zmm_t zmm_t14 = vtype::min(zmm[13], zmm18r);
216+
zmm_t zmm_t15 = vtype::min(zmm[14], zmm17r);
217+
zmm_t zmm_t16 = vtype::min(zmm[15], zmm16r);
218+
zmm_t zmm_t17 = vtype::permutexvar(rev_index, vtype::max(zmm[15], zmm16r));
219+
zmm_t zmm_t18 = vtype::permutexvar(rev_index, vtype::max(zmm[14], zmm17r));
220+
zmm_t zmm_t19 = vtype::permutexvar(rev_index, vtype::max(zmm[13], zmm18r));
221+
zmm_t zmm_t20 = vtype::permutexvar(rev_index, vtype::max(zmm[12], zmm19r));
222+
zmm_t zmm_t21 = vtype::permutexvar(rev_index, vtype::max(zmm[11], zmm20r));
223+
zmm_t zmm_t22 = vtype::permutexvar(rev_index, vtype::max(zmm[10], zmm21r));
224+
zmm_t zmm_t23 = vtype::permutexvar(rev_index, vtype::max(zmm[9], zmm22r));
225+
zmm_t zmm_t24 = vtype::permutexvar(rev_index, vtype::max(zmm[8], zmm23r));
226+
zmm_t zmm_t25 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm24r));
227+
zmm_t zmm_t26 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm25r));
228+
zmm_t zmm_t27 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm26r));
229+
zmm_t zmm_t28 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm27r));
230+
zmm_t zmm_t29 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm28r));
231+
zmm_t zmm_t30 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm29r));
232+
zmm_t zmm_t31 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm30r));
233+
zmm_t zmm_t32 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm31r));
234+
// Recusive half clear 16 zmm regs
235+
COEX<vtype>(zmm_t1, zmm_t9);
236+
COEX<vtype>(zmm_t2, zmm_t10);
237+
COEX<vtype>(zmm_t3, zmm_t11);
238+
COEX<vtype>(zmm_t4, zmm_t12);
239+
COEX<vtype>(zmm_t5, zmm_t13);
240+
COEX<vtype>(zmm_t6, zmm_t14);
241+
COEX<vtype>(zmm_t7, zmm_t15);
242+
COEX<vtype>(zmm_t8, zmm_t16);
243+
COEX<vtype>(zmm_t17, zmm_t25);
244+
COEX<vtype>(zmm_t18, zmm_t26);
245+
COEX<vtype>(zmm_t19, zmm_t27);
246+
COEX<vtype>(zmm_t20, zmm_t28);
247+
COEX<vtype>(zmm_t21, zmm_t29);
248+
COEX<vtype>(zmm_t22, zmm_t30);
249+
COEX<vtype>(zmm_t23, zmm_t31);
250+
COEX<vtype>(zmm_t24, zmm_t32);
251+
//
252+
COEX<vtype>(zmm_t1, zmm_t5);
253+
COEX<vtype>(zmm_t2, zmm_t6);
254+
COEX<vtype>(zmm_t3, zmm_t7);
255+
COEX<vtype>(zmm_t4, zmm_t8);
256+
COEX<vtype>(zmm_t9, zmm_t13);
257+
COEX<vtype>(zmm_t10, zmm_t14);
258+
COEX<vtype>(zmm_t11, zmm_t15);
259+
COEX<vtype>(zmm_t12, zmm_t16);
260+
COEX<vtype>(zmm_t17, zmm_t21);
261+
COEX<vtype>(zmm_t18, zmm_t22);
262+
COEX<vtype>(zmm_t19, zmm_t23);
263+
COEX<vtype>(zmm_t20, zmm_t24);
264+
COEX<vtype>(zmm_t25, zmm_t29);
265+
COEX<vtype>(zmm_t26, zmm_t30);
266+
COEX<vtype>(zmm_t27, zmm_t31);
267+
COEX<vtype>(zmm_t28, zmm_t32);
268+
//
269+
COEX<vtype>(zmm_t1, zmm_t3);
270+
COEX<vtype>(zmm_t2, zmm_t4);
271+
COEX<vtype>(zmm_t5, zmm_t7);
272+
COEX<vtype>(zmm_t6, zmm_t8);
273+
COEX<vtype>(zmm_t9, zmm_t11);
274+
COEX<vtype>(zmm_t10, zmm_t12);
275+
COEX<vtype>(zmm_t13, zmm_t15);
276+
COEX<vtype>(zmm_t14, zmm_t16);
277+
COEX<vtype>(zmm_t17, zmm_t19);
278+
COEX<vtype>(zmm_t18, zmm_t20);
279+
COEX<vtype>(zmm_t21, zmm_t23);
280+
COEX<vtype>(zmm_t22, zmm_t24);
281+
COEX<vtype>(zmm_t25, zmm_t27);
282+
COEX<vtype>(zmm_t26, zmm_t28);
283+
COEX<vtype>(zmm_t29, zmm_t31);
284+
COEX<vtype>(zmm_t30, zmm_t32);
285+
//
286+
COEX<vtype>(zmm_t1, zmm_t2);
287+
COEX<vtype>(zmm_t3, zmm_t4);
288+
COEX<vtype>(zmm_t5, zmm_t6);
289+
COEX<vtype>(zmm_t7, zmm_t8);
290+
COEX<vtype>(zmm_t9, zmm_t10);
291+
COEX<vtype>(zmm_t11, zmm_t12);
292+
COEX<vtype>(zmm_t13, zmm_t14);
293+
COEX<vtype>(zmm_t15, zmm_t16);
294+
COEX<vtype>(zmm_t17, zmm_t18);
295+
COEX<vtype>(zmm_t19, zmm_t20);
296+
COEX<vtype>(zmm_t21, zmm_t22);
297+
COEX<vtype>(zmm_t23, zmm_t24);
298+
COEX<vtype>(zmm_t25, zmm_t26);
299+
COEX<vtype>(zmm_t27, zmm_t28);
300+
COEX<vtype>(zmm_t29, zmm_t30);
301+
COEX<vtype>(zmm_t31, zmm_t32);
302+
//
303+
zmm[0] = bitonic_merge_zmm_64bit<vtype>(zmm_t1);
304+
zmm[1] = bitonic_merge_zmm_64bit<vtype>(zmm_t2);
305+
zmm[2] = bitonic_merge_zmm_64bit<vtype>(zmm_t3);
306+
zmm[3] = bitonic_merge_zmm_64bit<vtype>(zmm_t4);
307+
zmm[4] = bitonic_merge_zmm_64bit<vtype>(zmm_t5);
308+
zmm[5] = bitonic_merge_zmm_64bit<vtype>(zmm_t6);
309+
zmm[6] = bitonic_merge_zmm_64bit<vtype>(zmm_t7);
310+
zmm[7] = bitonic_merge_zmm_64bit<vtype>(zmm_t8);
311+
zmm[8] = bitonic_merge_zmm_64bit<vtype>(zmm_t9);
312+
zmm[9] = bitonic_merge_zmm_64bit<vtype>(zmm_t10);
313+
zmm[10] = bitonic_merge_zmm_64bit<vtype>(zmm_t11);
314+
zmm[11] = bitonic_merge_zmm_64bit<vtype>(zmm_t12);
315+
zmm[12] = bitonic_merge_zmm_64bit<vtype>(zmm_t13);
316+
zmm[13] = bitonic_merge_zmm_64bit<vtype>(zmm_t14);
317+
zmm[14] = bitonic_merge_zmm_64bit<vtype>(zmm_t15);
318+
zmm[15] = bitonic_merge_zmm_64bit<vtype>(zmm_t16);
319+
zmm[16] = bitonic_merge_zmm_64bit<vtype>(zmm_t17);
320+
zmm[17] = bitonic_merge_zmm_64bit<vtype>(zmm_t18);
321+
zmm[18] = bitonic_merge_zmm_64bit<vtype>(zmm_t19);
322+
zmm[19] = bitonic_merge_zmm_64bit<vtype>(zmm_t20);
323+
zmm[20] = bitonic_merge_zmm_64bit<vtype>(zmm_t21);
324+
zmm[21] = bitonic_merge_zmm_64bit<vtype>(zmm_t22);
325+
zmm[22] = bitonic_merge_zmm_64bit<vtype>(zmm_t23);
326+
zmm[23] = bitonic_merge_zmm_64bit<vtype>(zmm_t24);
327+
zmm[24] = bitonic_merge_zmm_64bit<vtype>(zmm_t25);
328+
zmm[25] = bitonic_merge_zmm_64bit<vtype>(zmm_t26);
329+
zmm[26] = bitonic_merge_zmm_64bit<vtype>(zmm_t27);
330+
zmm[27] = bitonic_merge_zmm_64bit<vtype>(zmm_t28);
331+
zmm[28] = bitonic_merge_zmm_64bit<vtype>(zmm_t29);
332+
zmm[29] = bitonic_merge_zmm_64bit<vtype>(zmm_t30);
333+
zmm[30] = bitonic_merge_zmm_64bit<vtype>(zmm_t31);
334+
zmm[31] = bitonic_merge_zmm_64bit<vtype>(zmm_t32);
335+
}
336+
182337
template <typename vtype, typename type_t>
183338
X86_SIMD_SORT_INLINE void sort_8_64bit(type_t *arr, int32_t N)
184339
{
@@ -378,6 +533,197 @@ X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N)
378533
vtype::mask_storeu(arr + 120, load_mask8, zmm[15]);
379534
}
380535

536+
template <typename vtype, typename type_t>
537+
X86_SIMD_SORT_INLINE void sort_256_64bit(type_t *arr, int32_t N)
538+
{
539+
if (N <= 128) {
540+
sort_128_64bit<vtype>(arr, N);
541+
return;
542+
}
543+
using zmm_t = typename vtype::zmm_t;
544+
using opmask_t = typename vtype::opmask_t;
545+
zmm_t zmm[32];
546+
zmm[0] = vtype::loadu(arr);
547+
zmm[1] = vtype::loadu(arr + 8);
548+
zmm[2] = vtype::loadu(arr + 16);
549+
zmm[3] = vtype::loadu(arr + 24);
550+
zmm[4] = vtype::loadu(arr + 32);
551+
zmm[5] = vtype::loadu(arr + 40);
552+
zmm[6] = vtype::loadu(arr + 48);
553+
zmm[7] = vtype::loadu(arr + 56);
554+
zmm[8] = vtype::loadu(arr + 64);
555+
zmm[9] = vtype::loadu(arr + 72);
556+
zmm[10] = vtype::loadu(arr + 80);
557+
zmm[11] = vtype::loadu(arr + 88);
558+
zmm[12] = vtype::loadu(arr + 96);
559+
zmm[13] = vtype::loadu(arr + 104);
560+
zmm[14] = vtype::loadu(arr + 112);
561+
zmm[15] = vtype::loadu(arr + 120);
562+
zmm[0] = sort_zmm_64bit<vtype>(zmm[0]);
563+
zmm[1] = sort_zmm_64bit<vtype>(zmm[1]);
564+
zmm[2] = sort_zmm_64bit<vtype>(zmm[2]);
565+
zmm[3] = sort_zmm_64bit<vtype>(zmm[3]);
566+
zmm[4] = sort_zmm_64bit<vtype>(zmm[4]);
567+
zmm[5] = sort_zmm_64bit<vtype>(zmm[5]);
568+
zmm[6] = sort_zmm_64bit<vtype>(zmm[6]);
569+
zmm[7] = sort_zmm_64bit<vtype>(zmm[7]);
570+
zmm[8] = sort_zmm_64bit<vtype>(zmm[8]);
571+
zmm[9] = sort_zmm_64bit<vtype>(zmm[9]);
572+
zmm[10] = sort_zmm_64bit<vtype>(zmm[10]);
573+
zmm[11] = sort_zmm_64bit<vtype>(zmm[11]);
574+
zmm[12] = sort_zmm_64bit<vtype>(zmm[12]);
575+
zmm[13] = sort_zmm_64bit<vtype>(zmm[13]);
576+
zmm[14] = sort_zmm_64bit<vtype>(zmm[14]);
577+
zmm[15] = sort_zmm_64bit<vtype>(zmm[15]);
578+
opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF;
579+
opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF;
580+
opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF;
581+
opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF;
582+
opmask_t load_mask9 = 0xFF, load_mask10 = 0xFF;
583+
opmask_t load_mask11 = 0xFF, load_mask12 = 0xFF;
584+
opmask_t load_mask13 = 0xFF, load_mask14 = 0xFF;
585+
opmask_t load_mask15 = 0xFF, load_mask16 = 0xFF;
586+
if (N != 256) {
587+
uint64_t combined_mask;
588+
if (N < 192) {
589+
combined_mask = (0x1ull << (N - 128)) - 0x1ull;
590+
load_mask1 = (combined_mask) & 0xFF;
591+
load_mask2 = (combined_mask >> 8) & 0xFF;
592+
load_mask3 = (combined_mask >> 16) & 0xFF;
593+
load_mask4 = (combined_mask >> 24) & 0xFF;
594+
load_mask5 = (combined_mask >> 32) & 0xFF;
595+
load_mask6 = (combined_mask >> 40) & 0xFF;
596+
load_mask7 = (combined_mask >> 48) & 0xFF;
597+
load_mask8 = (combined_mask >> 56) & 0xFF;
598+
load_mask9 = 0x00; load_mask10 = 0x0;
599+
load_mask11 = 0x00; load_mask12 = 0x00;
600+
load_mask13 = 0x00; load_mask14 = 0x00;
601+
load_mask15 = 0x00; load_mask16 = 0x00;
602+
}
603+
else {
604+
combined_mask = (0x1ull << (N - 192)) - 0x1ull;
605+
load_mask9 = (combined_mask) & 0xFF;
606+
load_mask10 = (combined_mask >> 8) & 0xFF;
607+
load_mask11 = (combined_mask >> 16) & 0xFF;
608+
load_mask12 = (combined_mask >> 24) & 0xFF;
609+
load_mask13 = (combined_mask >> 32) & 0xFF;
610+
load_mask14 = (combined_mask >> 40) & 0xFF;
611+
load_mask15 = (combined_mask >> 48) & 0xFF;
612+
load_mask16 = (combined_mask >> 56) & 0xFF;
613+
}
614+
}
615+
zmm[16] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 128);
616+
zmm[17] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 136);
617+
zmm[18] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 144);
618+
zmm[19] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 152);
619+
zmm[20] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 160);
620+
zmm[21] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 168);
621+
zmm[22] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 176);
622+
zmm[23] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 184);
623+
if (N < 192) {
624+
zmm[24] = vtype::zmm_max();
625+
zmm[25] = vtype::zmm_max();
626+
zmm[26] = vtype::zmm_max();
627+
zmm[27] = vtype::zmm_max();
628+
zmm[28] = vtype::zmm_max();
629+
zmm[29] = vtype::zmm_max();
630+
zmm[30] = vtype::zmm_max();
631+
zmm[31] = vtype::zmm_max();
632+
}
633+
else {
634+
zmm[24] = vtype::mask_loadu(vtype::zmm_max(), load_mask9, arr + 192);
635+
zmm[25] = vtype::mask_loadu(vtype::zmm_max(), load_mask10, arr + 200);
636+
zmm[26] = vtype::mask_loadu(vtype::zmm_max(), load_mask11, arr + 208);
637+
zmm[27] = vtype::mask_loadu(vtype::zmm_max(), load_mask12, arr + 216);
638+
zmm[28] = vtype::mask_loadu(vtype::zmm_max(), load_mask13, arr + 224);
639+
zmm[29] = vtype::mask_loadu(vtype::zmm_max(), load_mask14, arr + 232);
640+
zmm[30] = vtype::mask_loadu(vtype::zmm_max(), load_mask15, arr + 240);
641+
zmm[31] = vtype::mask_loadu(vtype::zmm_max(), load_mask16, arr + 248);
642+
}
643+
zmm[16] = sort_zmm_64bit<vtype>(zmm[16]);
644+
zmm[17] = sort_zmm_64bit<vtype>(zmm[17]);
645+
zmm[18] = sort_zmm_64bit<vtype>(zmm[18]);
646+
zmm[19] = sort_zmm_64bit<vtype>(zmm[19]);
647+
zmm[20] = sort_zmm_64bit<vtype>(zmm[20]);
648+
zmm[21] = sort_zmm_64bit<vtype>(zmm[21]);
649+
zmm[22] = sort_zmm_64bit<vtype>(zmm[22]);
650+
zmm[23] = sort_zmm_64bit<vtype>(zmm[23]);
651+
zmm[24] = sort_zmm_64bit<vtype>(zmm[24]);
652+
zmm[25] = sort_zmm_64bit<vtype>(zmm[25]);
653+
zmm[26] = sort_zmm_64bit<vtype>(zmm[26]);
654+
zmm[27] = sort_zmm_64bit<vtype>(zmm[27]);
655+
zmm[28] = sort_zmm_64bit<vtype>(zmm[28]);
656+
zmm[29] = sort_zmm_64bit<vtype>(zmm[29]);
657+
zmm[30] = sort_zmm_64bit<vtype>(zmm[30]);
658+
zmm[31] = sort_zmm_64bit<vtype>(zmm[31]);
659+
bitonic_merge_two_zmm_64bit<vtype>(zmm[0], zmm[1]);
660+
bitonic_merge_two_zmm_64bit<vtype>(zmm[2], zmm[3]);
661+
bitonic_merge_two_zmm_64bit<vtype>(zmm[4], zmm[5]);
662+
bitonic_merge_two_zmm_64bit<vtype>(zmm[6], zmm[7]);
663+
bitonic_merge_two_zmm_64bit<vtype>(zmm[8], zmm[9]);
664+
bitonic_merge_two_zmm_64bit<vtype>(zmm[10], zmm[11]);
665+
bitonic_merge_two_zmm_64bit<vtype>(zmm[12], zmm[13]);
666+
bitonic_merge_two_zmm_64bit<vtype>(zmm[14], zmm[15]);
667+
bitonic_merge_two_zmm_64bit<vtype>(zmm[16], zmm[17]);
668+
bitonic_merge_two_zmm_64bit<vtype>(zmm[18], zmm[19]);
669+
bitonic_merge_two_zmm_64bit<vtype>(zmm[20], zmm[21]);
670+
bitonic_merge_two_zmm_64bit<vtype>(zmm[22], zmm[23]);
671+
bitonic_merge_two_zmm_64bit<vtype>(zmm[24], zmm[25]);
672+
bitonic_merge_two_zmm_64bit<vtype>(zmm[26], zmm[27]);
673+
bitonic_merge_two_zmm_64bit<vtype>(zmm[28], zmm[29]);
674+
bitonic_merge_two_zmm_64bit<vtype>(zmm[30], zmm[31]);
675+
bitonic_merge_four_zmm_64bit<vtype>(zmm);
676+
bitonic_merge_four_zmm_64bit<vtype>(zmm + 4);
677+
bitonic_merge_four_zmm_64bit<vtype>(zmm + 8);
678+
bitonic_merge_four_zmm_64bit<vtype>(zmm + 12);
679+
bitonic_merge_four_zmm_64bit<vtype>(zmm + 16);
680+
bitonic_merge_four_zmm_64bit<vtype>(zmm + 20);
681+
bitonic_merge_four_zmm_64bit<vtype>(zmm + 24);
682+
bitonic_merge_four_zmm_64bit<vtype>(zmm + 28);
683+
bitonic_merge_eight_zmm_64bit<vtype>(zmm);
684+
bitonic_merge_eight_zmm_64bit<vtype>(zmm + 8);
685+
bitonic_merge_eight_zmm_64bit<vtype>(zmm + 16);
686+
bitonic_merge_eight_zmm_64bit<vtype>(zmm + 24);
687+
bitonic_merge_sixteen_zmm_64bit<vtype>(zmm);
688+
bitonic_merge_sixteen_zmm_64bit<vtype>(zmm + 16);
689+
bitonic_merge_32_zmm_64bit<vtype>(zmm);
690+
vtype::storeu(arr, zmm[0]);
691+
vtype::storeu(arr + 8, zmm[1]);
692+
vtype::storeu(arr + 16, zmm[2]);
693+
vtype::storeu(arr + 24, zmm[3]);
694+
vtype::storeu(arr + 32, zmm[4]);
695+
vtype::storeu(arr + 40, zmm[5]);
696+
vtype::storeu(arr + 48, zmm[6]);
697+
vtype::storeu(arr + 56, zmm[7]);
698+
vtype::storeu(arr + 64, zmm[8]);
699+
vtype::storeu(arr + 72, zmm[9]);
700+
vtype::storeu(arr + 80, zmm[10]);
701+
vtype::storeu(arr + 88, zmm[11]);
702+
vtype::storeu(arr + 96, zmm[12]);
703+
vtype::storeu(arr + 104, zmm[13]);
704+
vtype::storeu(arr + 112, zmm[14]);
705+
vtype::storeu(arr + 120, zmm[15]);
706+
vtype::mask_storeu(arr + 128, load_mask1, zmm[16]);
707+
vtype::mask_storeu(arr + 136, load_mask2, zmm[17]);
708+
vtype::mask_storeu(arr + 144, load_mask3, zmm[18]);
709+
vtype::mask_storeu(arr + 152, load_mask4, zmm[19]);
710+
vtype::mask_storeu(arr + 160, load_mask5, zmm[20]);
711+
vtype::mask_storeu(arr + 168, load_mask6, zmm[21]);
712+
vtype::mask_storeu(arr + 176, load_mask7, zmm[22]);
713+
vtype::mask_storeu(arr + 184, load_mask8, zmm[23]);
714+
if (N > 192) {
715+
vtype::mask_storeu(arr + 192, load_mask9, zmm[24]);
716+
vtype::mask_storeu(arr + 200, load_mask10, zmm[25]);
717+
vtype::mask_storeu(arr + 208, load_mask11, zmm[26]);
718+
vtype::mask_storeu(arr + 216, load_mask12, zmm[27]);
719+
vtype::mask_storeu(arr + 224, load_mask13, zmm[28]);
720+
vtype::mask_storeu(arr + 232, load_mask14, zmm[29]);
721+
vtype::mask_storeu(arr + 240, load_mask15, zmm[30]);
722+
vtype::mask_storeu(arr + 248, load_mask16, zmm[31]);
723+
}
724+
725+
}
726+
381727
template <typename vtype, typename type_t>
382728
static void
383729
qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
@@ -392,8 +738,8 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
392738
/*
393739
* Base case: use bitonic networks to sort arrays <= 256
394740
*/
395-
if (right + 1 - left <= 128) {
396-
sort_128_64bit<vtype>(arr + left, (int32_t)(right + 1 - left));
741+
if (right + 1 - left <= 256) {
742+
sort_256_64bit<vtype>(arr + left, (int32_t)(right + 1 - left));
397743
return;
398744
}
399745

0 commit comments

Comments
 (0)