@@ -179,6 +179,161 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm)
     zmm[15] = bitonic_merge_zmm_64bit<vtype>(zmm_t16);
 }
 
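+// Merges 32 zmm registers (256 x 64-bit elements): assumes zmm[0..15] and
+// zmm[16..31] are each already sorted ascending.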
+template <typename vtype, typename zmm_t = typename vtype::zmm_t>
+X86_SIMD_SORT_INLINE void bitonic_merge_32_zmm_64bit(zmm_t *zmm)
+{
+    const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2);
+    zmm_t zmm16r = vtype::permutexvar(rev_index, zmm[16]);
+    zmm_t zmm17r = vtype::permutexvar(rev_index, zmm[17]);
+    zmm_t zmm18r = vtype::permutexvar(rev_index, zmm[18]);
+    zmm_t zmm19r = vtype::permutexvar(rev_index, zmm[19]);
+    zmm_t zmm20r = vtype::permutexvar(rev_index, zmm[20]);
+    zmm_t zmm21r = vtype::permutexvar(rev_index, zmm[21]);
+    zmm_t zmm22r = vtype::permutexvar(rev_index, zmm[22]);
+    zmm_t zmm23r = vtype::permutexvar(rev_index, zmm[23]);
+    zmm_t zmm24r = vtype::permutexvar(rev_index, zmm[24]);
+    zmm_t zmm25r = vtype::permutexvar(rev_index, zmm[25]);
+    zmm_t zmm26r = vtype::permutexvar(rev_index, zmm[26]);
+    zmm_t zmm27r = vtype::permutexvar(rev_index, zmm[27]);
+    zmm_t zmm28r = vtype::permutexvar(rev_index, zmm[28]);
+    zmm_t zmm29r = vtype::permutexvar(rev_index, zmm[29]);
+    zmm_t zmm30r = vtype::permutexvar(rev_index, zmm[30]);
+    zmm_t zmm31r = vtype::permutexvar(rev_index, zmm[31]);
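+    // First compare-exchange of the merge: pair zmm[k] with the reversed
+    // zmm[31-k]; mins stay in the lower half, maxes (re-reversed) in the upper.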
+    zmm_t zmm_t1 = vtype::min(zmm[0], zmm31r);
+    zmm_t zmm_t2 = vtype::min(zmm[1], zmm30r);
+    zmm_t zmm_t3 = vtype::min(zmm[2], zmm29r);
+    zmm_t zmm_t4 = vtype::min(zmm[3], zmm28r);
+    zmm_t zmm_t5 = vtype::min(zmm[4], zmm27r);
+    zmm_t zmm_t6 = vtype::min(zmm[5], zmm26r);
+    zmm_t zmm_t7 = vtype::min(zmm[6], zmm25r);
+    zmm_t zmm_t8 = vtype::min(zmm[7], zmm24r);
+    zmm_t zmm_t9 = vtype::min(zmm[8], zmm23r);
+    zmm_t zmm_t10 = vtype::min(zmm[9], zmm22r);
+    zmm_t zmm_t11 = vtype::min(zmm[10], zmm21r);
+    zmm_t zmm_t12 = vtype::min(zmm[11], zmm20r);
+    zmm_t zmm_t13 = vtype::min(zmm[12], zmm19r);
+    zmm_t zmm_t14 = vtype::min(zmm[13], zmm18r);
+    zmm_t zmm_t15 = vtype::min(zmm[14], zmm17r);
+    zmm_t zmm_t16 = vtype::min(zmm[15], zmm16r);
+    zmm_t zmm_t17 = vtype::permutexvar(rev_index, vtype::max(zmm[15], zmm16r));
+    zmm_t zmm_t18 = vtype::permutexvar(rev_index, vtype::max(zmm[14], zmm17r));
+    zmm_t zmm_t19 = vtype::permutexvar(rev_index, vtype::max(zmm[13], zmm18r));
+    zmm_t zmm_t20 = vtype::permutexvar(rev_index, vtype::max(zmm[12], zmm19r));
+    zmm_t zmm_t21 = vtype::permutexvar(rev_index, vtype::max(zmm[11], zmm20r));
+    zmm_t zmm_t22 = vtype::permutexvar(rev_index, vtype::max(zmm[10], zmm21r));
+    zmm_t zmm_t23 = vtype::permutexvar(rev_index, vtype::max(zmm[9], zmm22r));
+    zmm_t zmm_t24 = vtype::permutexvar(rev_index, vtype::max(zmm[8], zmm23r));
+    zmm_t zmm_t25 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm24r));
+    zmm_t zmm_t26 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm25r));
+    zmm_t zmm_t27 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm26r));
+    zmm_t zmm_t28 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm27r));
+    zmm_t zmm_t29 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm28r));
+    zmm_t zmm_t30 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm29r));
+    zmm_t zmm_t31 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm30r));
+    zmm_t zmm_t32 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm31r));
+    // Recursive half cleaner: 16 zmm regs
+    COEX<vtype>(zmm_t1, zmm_t9);
+    COEX<vtype>(zmm_t2, zmm_t10);
+    COEX<vtype>(zmm_t3, zmm_t11);
+    COEX<vtype>(zmm_t4, zmm_t12);
+    COEX<vtype>(zmm_t5, zmm_t13);
+    COEX<vtype>(zmm_t6, zmm_t14);
+    COEX<vtype>(zmm_t7, zmm_t15);
+    COEX<vtype>(zmm_t8, zmm_t16);
+    COEX<vtype>(zmm_t17, zmm_t25);
+    COEX<vtype>(zmm_t18, zmm_t26);
+    COEX<vtype>(zmm_t19, zmm_t27);
+    COEX<vtype>(zmm_t20, zmm_t28);
+    COEX<vtype>(zmm_t21, zmm_t29);
+    COEX<vtype>(zmm_t22, zmm_t30);
+    COEX<vtype>(zmm_t23, zmm_t31);
+    COEX<vtype>(zmm_t24, zmm_t32);
+    //
+    COEX<vtype>(zmm_t1, zmm_t5);
+    COEX<vtype>(zmm_t2, zmm_t6);
+    COEX<vtype>(zmm_t3, zmm_t7);
+    COEX<vtype>(zmm_t4, zmm_t8);
+    COEX<vtype>(zmm_t9, zmm_t13);
+    COEX<vtype>(zmm_t10, zmm_t14);
+    COEX<vtype>(zmm_t11, zmm_t15);
+    COEX<vtype>(zmm_t12, zmm_t16);
+    COEX<vtype>(zmm_t17, zmm_t21);
+    COEX<vtype>(zmm_t18, zmm_t22);
+    COEX<vtype>(zmm_t19, zmm_t23);
+    COEX<vtype>(zmm_t20, zmm_t24);
+    COEX<vtype>(zmm_t25, zmm_t29);
+    COEX<vtype>(zmm_t26, zmm_t30);
+    COEX<vtype>(zmm_t27, zmm_t31);
+    COEX<vtype>(zmm_t28, zmm_t32);
+    //
+    COEX<vtype>(zmm_t1, zmm_t3);
+    COEX<vtype>(zmm_t2, zmm_t4);
+    COEX<vtype>(zmm_t5, zmm_t7);
+    COEX<vtype>(zmm_t6, zmm_t8);
+    COEX<vtype>(zmm_t9, zmm_t11);
+    COEX<vtype>(zmm_t10, zmm_t12);
+    COEX<vtype>(zmm_t13, zmm_t15);
+    COEX<vtype>(zmm_t14, zmm_t16);
+    COEX<vtype>(zmm_t17, zmm_t19);
+    COEX<vtype>(zmm_t18, zmm_t20);
+    COEX<vtype>(zmm_t21, zmm_t23);
+    COEX<vtype>(zmm_t22, zmm_t24);
+    COEX<vtype>(zmm_t25, zmm_t27);
+    COEX<vtype>(zmm_t26, zmm_t28);
+    COEX<vtype>(zmm_t29, zmm_t31);
+    COEX<vtype>(zmm_t30, zmm_t32);
+    //
+    COEX<vtype>(zmm_t1, zmm_t2);
+    COEX<vtype>(zmm_t3, zmm_t4);
+    COEX<vtype>(zmm_t5, zmm_t6);
+    COEX<vtype>(zmm_t7, zmm_t8);
+    COEX<vtype>(zmm_t9, zmm_t10);
+    COEX<vtype>(zmm_t11, zmm_t12);
+    COEX<vtype>(zmm_t13, zmm_t14);
+    COEX<vtype>(zmm_t15, zmm_t16);
+    COEX<vtype>(zmm_t17, zmm_t18);
+    COEX<vtype>(zmm_t19, zmm_t20);
+    COEX<vtype>(zmm_t21, zmm_t22);
+    COEX<vtype>(zmm_t23, zmm_t24);
+    COEX<vtype>(zmm_t25, zmm_t26);
+    COEX<vtype>(zmm_t27, zmm_t28);
+    COEX<vtype>(zmm_t29, zmm_t30);
+    COEX<vtype>(zmm_t31, zmm_t32);
+    //
+    zmm[0] = bitonic_merge_zmm_64bit<vtype>(zmm_t1);
+    zmm[1] = bitonic_merge_zmm_64bit<vtype>(zmm_t2);
+    zmm[2] = bitonic_merge_zmm_64bit<vtype>(zmm_t3);
+    zmm[3] = bitonic_merge_zmm_64bit<vtype>(zmm_t4);
+    zmm[4] = bitonic_merge_zmm_64bit<vtype>(zmm_t5);
+    zmm[5] = bitonic_merge_zmm_64bit<vtype>(zmm_t6);
+    zmm[6] = bitonic_merge_zmm_64bit<vtype>(zmm_t7);
+    zmm[7] = bitonic_merge_zmm_64bit<vtype>(zmm_t8);
+    zmm[8] = bitonic_merge_zmm_64bit<vtype>(zmm_t9);
+    zmm[9] = bitonic_merge_zmm_64bit<vtype>(zmm_t10);
+    zmm[10] = bitonic_merge_zmm_64bit<vtype>(zmm_t11);
+    zmm[11] = bitonic_merge_zmm_64bit<vtype>(zmm_t12);
+    zmm[12] = bitonic_merge_zmm_64bit<vtype>(zmm_t13);
+    zmm[13] = bitonic_merge_zmm_64bit<vtype>(zmm_t14);
+    zmm[14] = bitonic_merge_zmm_64bit<vtype>(zmm_t15);
+    zmm[15] = bitonic_merge_zmm_64bit<vtype>(zmm_t16);
+    zmm[16] = bitonic_merge_zmm_64bit<vtype>(zmm_t17);
+    zmm[17] = bitonic_merge_zmm_64bit<vtype>(zmm_t18);
+    zmm[18] = bitonic_merge_zmm_64bit<vtype>(zmm_t19);
+    zmm[19] = bitonic_merge_zmm_64bit<vtype>(zmm_t20);
+    zmm[20] = bitonic_merge_zmm_64bit<vtype>(zmm_t21);
+    zmm[21] = bitonic_merge_zmm_64bit<vtype>(zmm_t22);
+    zmm[22] = bitonic_merge_zmm_64bit<vtype>(zmm_t23);
+    zmm[23] = bitonic_merge_zmm_64bit<vtype>(zmm_t24);
+    zmm[24] = bitonic_merge_zmm_64bit<vtype>(zmm_t25);
+    zmm[25] = bitonic_merge_zmm_64bit<vtype>(zmm_t26);
+    zmm[26] = bitonic_merge_zmm_64bit<vtype>(zmm_t27);
+    zmm[27] = bitonic_merge_zmm_64bit<vtype>(zmm_t28);
+    zmm[28] = bitonic_merge_zmm_64bit<vtype>(zmm_t29);
+    zmm[29] = bitonic_merge_zmm_64bit<vtype>(zmm_t30);
+    zmm[30] = bitonic_merge_zmm_64bit<vtype>(zmm_t31);
+    zmm[31] = bitonic_merge_zmm_64bit<vtype>(zmm_t32);
+}
+
 template <typename vtype, typename type_t>
 X86_SIMD_SORT_INLINE void sort_8_64bit(type_t *arr, int32_t N)
 {
@@ -378,6 +533,197 @@ X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N)
     vtype::mask_storeu(arr + 120, load_mask8, zmm[15]);
 }
 
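+// Sorts up to 256 64-bit elements with a bitonic network over 32 zmm registers;
+// delegates to sort_128_64bit when N <= 128.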
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE void sort_256_64bit(type_t *arr, int32_t N)
+{
+    if (N <= 128) {
+        sort_128_64bit<vtype>(arr, N);
+        return;
+    }
+    using zmm_t = typename vtype::zmm_t;
+    using opmask_t = typename vtype::opmask_t;
+    zmm_t zmm[32];
+    zmm[0] = vtype::loadu(arr);
+    zmm[1] = vtype::loadu(arr + 8);
+    zmm[2] = vtype::loadu(arr + 16);
+    zmm[3] = vtype::loadu(arr + 24);
+    zmm[4] = vtype::loadu(arr + 32);
+    zmm[5] = vtype::loadu(arr + 40);
+    zmm[6] = vtype::loadu(arr + 48);
+    zmm[7] = vtype::loadu(arr + 56);
+    zmm[8] = vtype::loadu(arr + 64);
+    zmm[9] = vtype::loadu(arr + 72);
+    zmm[10] = vtype::loadu(arr + 80);
+    zmm[11] = vtype::loadu(arr + 88);
+    zmm[12] = vtype::loadu(arr + 96);
+    zmm[13] = vtype::loadu(arr + 104);
+    zmm[14] = vtype::loadu(arr + 112);
+    zmm[15] = vtype::loadu(arr + 120);
+    zmm[0] = sort_zmm_64bit<vtype>(zmm[0]);
+    zmm[1] = sort_zmm_64bit<vtype>(zmm[1]);
+    zmm[2] = sort_zmm_64bit<vtype>(zmm[2]);
+    zmm[3] = sort_zmm_64bit<vtype>(zmm[3]);
+    zmm[4] = sort_zmm_64bit<vtype>(zmm[4]);
+    zmm[5] = sort_zmm_64bit<vtype>(zmm[5]);
+    zmm[6] = sort_zmm_64bit<vtype>(zmm[6]);
+    zmm[7] = sort_zmm_64bit<vtype>(zmm[7]);
+    zmm[8] = sort_zmm_64bit<vtype>(zmm[8]);
+    zmm[9] = sort_zmm_64bit<vtype>(zmm[9]);
+    zmm[10] = sort_zmm_64bit<vtype>(zmm[10]);
+    zmm[11] = sort_zmm_64bit<vtype>(zmm[11]);
+    zmm[12] = sort_zmm_64bit<vtype>(zmm[12]);
+    zmm[13] = sort_zmm_64bit<vtype>(zmm[13]);
+    zmm[14] = sort_zmm_64bit<vtype>(zmm[14]);
+    zmm[15] = sort_zmm_64bit<vtype>(zmm[15]);
+    opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF;
+    opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF;
+    opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF;
+    opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF;
+    opmask_t load_mask9 = 0xFF, load_mask10 = 0xFF;
+    opmask_t load_mask11 = 0xFF, load_mask12 = 0xFF;
+    opmask_t load_mask13 = 0xFF, load_mask14 = 0xFF;
+    opmask_t load_mask15 = 0xFF, load_mask16 = 0xFF;
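+    // For a partial tail (N < 256), build a bitmask with the low (N - 128) or
+    // (N - 192) bits set and slice it into 8-bit opmasks, one per register.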
+    if (N != 256) {
+        uint64_t combined_mask;
+        if (N < 192) {
+            combined_mask = (0x1ull << (N - 128)) - 0x1ull;
+            load_mask1 = (combined_mask) & 0xFF;
+            load_mask2 = (combined_mask >> 8) & 0xFF;
+            load_mask3 = (combined_mask >> 16) & 0xFF;
+            load_mask4 = (combined_mask >> 24) & 0xFF;
+            load_mask5 = (combined_mask >> 32) & 0xFF;
+            load_mask6 = (combined_mask >> 40) & 0xFF;
+            load_mask7 = (combined_mask >> 48) & 0xFF;
+            load_mask8 = (combined_mask >> 56) & 0xFF;
+            load_mask9 = 0x00; load_mask10 = 0x00;
+            load_mask11 = 0x00; load_mask12 = 0x00;
+            load_mask13 = 0x00; load_mask14 = 0x00;
+            load_mask15 = 0x00; load_mask16 = 0x00;
+        }
+        else {
+            combined_mask = (0x1ull << (N - 192)) - 0x1ull;
+            load_mask9 = (combined_mask) & 0xFF;
+            load_mask10 = (combined_mask >> 8) & 0xFF;
+            load_mask11 = (combined_mask >> 16) & 0xFF;
+            load_mask12 = (combined_mask >> 24) & 0xFF;
+            load_mask13 = (combined_mask >> 32) & 0xFF;
+            load_mask14 = (combined_mask >> 40) & 0xFF;
+            load_mask15 = (combined_mask >> 48) & 0xFF;
+            load_mask16 = (combined_mask >> 56) & 0xFF;
+        }
+    }
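+    // Load elements 128..255 with the masks above; lanes past N are filled
+    // with the type's max value so they stay at the end of the sorted output.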
+    zmm[16] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 128);
+    zmm[17] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 136);
+    zmm[18] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 144);
+    zmm[19] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 152);
+    zmm[20] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 160);
+    zmm[21] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 168);
+    zmm[22] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 176);
+    zmm[23] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 184);
+    if (N < 192) {
+        zmm[24] = vtype::zmm_max();
+        zmm[25] = vtype::zmm_max();
+        zmm[26] = vtype::zmm_max();
+        zmm[27] = vtype::zmm_max();
+        zmm[28] = vtype::zmm_max();
+        zmm[29] = vtype::zmm_max();
+        zmm[30] = vtype::zmm_max();
+        zmm[31] = vtype::zmm_max();
+    }
+    else {
+        zmm[24] = vtype::mask_loadu(vtype::zmm_max(), load_mask9, arr + 192);
+        zmm[25] = vtype::mask_loadu(vtype::zmm_max(), load_mask10, arr + 200);
+        zmm[26] = vtype::mask_loadu(vtype::zmm_max(), load_mask11, arr + 208);
+        zmm[27] = vtype::mask_loadu(vtype::zmm_max(), load_mask12, arr + 216);
+        zmm[28] = vtype::mask_loadu(vtype::zmm_max(), load_mask13, arr + 224);
+        zmm[29] = vtype::mask_loadu(vtype::zmm_max(), load_mask14, arr + 232);
+        zmm[30] = vtype::mask_loadu(vtype::zmm_max(), load_mask15, arr + 240);
+        zmm[31] = vtype::mask_loadu(vtype::zmm_max(), load_mask16, arr + 248);
+    }
+    zmm[16] = sort_zmm_64bit<vtype>(zmm[16]);
+    zmm[17] = sort_zmm_64bit<vtype>(zmm[17]);
+    zmm[18] = sort_zmm_64bit<vtype>(zmm[18]);
+    zmm[19] = sort_zmm_64bit<vtype>(zmm[19]);
+    zmm[20] = sort_zmm_64bit<vtype>(zmm[20]);
+    zmm[21] = sort_zmm_64bit<vtype>(zmm[21]);
+    zmm[22] = sort_zmm_64bit<vtype>(zmm[22]);
+    zmm[23] = sort_zmm_64bit<vtype>(zmm[23]);
+    zmm[24] = sort_zmm_64bit<vtype>(zmm[24]);
+    zmm[25] = sort_zmm_64bit<vtype>(zmm[25]);
+    zmm[26] = sort_zmm_64bit<vtype>(zmm[26]);
+    zmm[27] = sort_zmm_64bit<vtype>(zmm[27]);
+    zmm[28] = sort_zmm_64bit<vtype>(zmm[28]);
+    zmm[29] = sort_zmm_64bit<vtype>(zmm[29]);
+    zmm[30] = sort_zmm_64bit<vtype>(zmm[30]);
+    zmm[31] = sort_zmm_64bit<vtype>(zmm[31]);
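+    // Merge the 32 individually sorted registers: pairs, then groups of 4, 8,
+    // 16, and finally all 32.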
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[0], zmm[1]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[2], zmm[3]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[4], zmm[5]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[6], zmm[7]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[8], zmm[9]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[10], zmm[11]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[12], zmm[13]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[14], zmm[15]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[16], zmm[17]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[18], zmm[19]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[20], zmm[21]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[22], zmm[23]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[24], zmm[25]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[26], zmm[27]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[28], zmm[29]);
+    bitonic_merge_two_zmm_64bit<vtype>(zmm[30], zmm[31]);
+    bitonic_merge_four_zmm_64bit<vtype>(zmm);
+    bitonic_merge_four_zmm_64bit<vtype>(zmm + 4);
+    bitonic_merge_four_zmm_64bit<vtype>(zmm + 8);
+    bitonic_merge_four_zmm_64bit<vtype>(zmm + 12);
+    bitonic_merge_four_zmm_64bit<vtype>(zmm + 16);
+    bitonic_merge_four_zmm_64bit<vtype>(zmm + 20);
+    bitonic_merge_four_zmm_64bit<vtype>(zmm + 24);
+    bitonic_merge_four_zmm_64bit<vtype>(zmm + 28);
+    bitonic_merge_eight_zmm_64bit<vtype>(zmm);
+    bitonic_merge_eight_zmm_64bit<vtype>(zmm + 8);
+    bitonic_merge_eight_zmm_64bit<vtype>(zmm + 16);
+    bitonic_merge_eight_zmm_64bit<vtype>(zmm + 24);
+    bitonic_merge_sixteen_zmm_64bit<vtype>(zmm);
+    bitonic_merge_sixteen_zmm_64bit<vtype>(zmm + 16);
+    bitonic_merge_32_zmm_64bit<vtype>(zmm);
+    vtype::storeu(arr, zmm[0]);
+    vtype::storeu(arr + 8, zmm[1]);
+    vtype::storeu(arr + 16, zmm[2]);
+    vtype::storeu(arr + 24, zmm[3]);
+    vtype::storeu(arr + 32, zmm[4]);
+    vtype::storeu(arr + 40, zmm[5]);
+    vtype::storeu(arr + 48, zmm[6]);
+    vtype::storeu(arr + 56, zmm[7]);
+    vtype::storeu(arr + 64, zmm[8]);
+    vtype::storeu(arr + 72, zmm[9]);
+    vtype::storeu(arr + 80, zmm[10]);
+    vtype::storeu(arr + 88, zmm[11]);
+    vtype::storeu(arr + 96, zmm[12]);
+    vtype::storeu(arr + 104, zmm[13]);
+    vtype::storeu(arr + 112, zmm[14]);
+    vtype::storeu(arr + 120, zmm[15]);
+    vtype::mask_storeu(arr + 128, load_mask1, zmm[16]);
+    vtype::mask_storeu(arr + 136, load_mask2, zmm[17]);
+    vtype::mask_storeu(arr + 144, load_mask3, zmm[18]);
+    vtype::mask_storeu(arr + 152, load_mask4, zmm[19]);
+    vtype::mask_storeu(arr + 160, load_mask5, zmm[20]);
+    vtype::mask_storeu(arr + 168, load_mask6, zmm[21]);
+    vtype::mask_storeu(arr + 176, load_mask7, zmm[22]);
+    vtype::mask_storeu(arr + 184, load_mask8, zmm[23]);
+    if (N > 192) {
+        vtype::mask_storeu(arr + 192, load_mask9, zmm[24]);
+        vtype::mask_storeu(arr + 200, load_mask10, zmm[25]);
+        vtype::mask_storeu(arr + 208, load_mask11, zmm[26]);
+        vtype::mask_storeu(arr + 216, load_mask12, zmm[27]);
+        vtype::mask_storeu(arr + 224, load_mask13, zmm[28]);
+        vtype::mask_storeu(arr + 232, load_mask14, zmm[29]);
+        vtype::mask_storeu(arr + 240, load_mask15, zmm[30]);
+        vtype::mask_storeu(arr + 248, load_mask16, zmm[31]);
+    }
+}
+
 template <typename vtype, typename type_t>
 static void
 qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
@@ -392,8 +738,8 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
     /*
-     * Base case: use bitonic networks to sort arrays <= 128
+     * Base case: use bitonic networks to sort arrays <= 256
      */
-    if (right + 1 - left <= 128) {
-        sort_128_64bit<vtype>(arr + left, (int32_t)(right + 1 - left));
+    if (right + 1 - left <= 256) {
+        sort_256_64bit<vtype>(arr + left, (int32_t)(right + 1 - left));
         return;
     }