@@ -10702,6 +10702,127 @@ enum ggml_opt_result ggml_opt(
10702
10702
10703
10703
////////////////////////////////////////////////////////////////////////////////
10704
10704
10705
+ size_t ggml_quantize_q4_0 (float * src , void * dst , int n , int k , int qk , int64_t * hist ) {
10706
+ const int nb = k / qk ;
10707
+ const size_t bs = (sizeof (float ) + sizeof (uint8_t )* qk /2 );
10708
+ const size_t row_size = nb * bs ;
10709
+
10710
+ assert (k % qk == 0 );
10711
+
10712
+ const size_t pp_size = qk / 2 ;
10713
+ uint8_t * pp = (uint8_t * ) alloca (pp_size );
10714
+
10715
+ char * pdst = (char * ) dst ;
10716
+
10717
+ for (int j = 0 ; j < n ; j += k ) {
10718
+ uint8_t * pd = (uint8_t * ) (pdst + (j /k )* row_size + 0 * bs );
10719
+ uint8_t * pb = (uint8_t * ) (pdst + (j /k )* row_size + 0 * bs + sizeof (float ));
10720
+
10721
+ for (int i = 0 ; i < nb ; i ++ ) {
10722
+ float amax = 0.0f ; // absolute max
10723
+
10724
+ {
10725
+ for (int l = 0 ; l < qk ; l ++ ) {
10726
+ const float v = src [j + i * qk + l ];
10727
+ amax = MAX (amax , fabsf (v ));
10728
+ }
10729
+
10730
+ const float d = amax / ((1 << 3 ) - 1 );
10731
+ const float id = d ? 1.0f /d : 0.0f ;
10732
+
10733
+ * (float * ) pd = d ;
10734
+ pd += bs ;
10735
+
10736
+ for (int l = 0 ; l < qk ; l += 2 ) {
10737
+ const float v0 = (src [j + i * qk + l + 0 ])* id ;
10738
+ const float v1 = (src [j + i * qk + l + 1 ])* id ;
10739
+
10740
+ const uint8_t vi0 = ((int8_t ) (round (v0 ))) + 8 ;
10741
+ const uint8_t vi1 = ((int8_t ) (round (v1 ))) + 8 ;
10742
+
10743
+ assert (vi0 >= 0 && vi0 < 16 );
10744
+ assert (vi1 >= 0 && vi1 < 16 );
10745
+
10746
+ hist [vi0 ]++ ;
10747
+ hist [vi1 ]++ ;
10748
+
10749
+ pp [l /2 ] = vi0 | (vi1 << 4 );
10750
+ }
10751
+
10752
+ memcpy (pb , pp , pp_size );
10753
+ pb += bs ;
10754
+ }
10755
+ }
10756
+ }
10757
+
10758
+ return (n /k )* row_size ;
10759
+ }
10760
+
10761
+ size_t ggml_quantize_q4_1 (float * src , void * dst , int n , int k , int qk , int64_t * hist ) {
10762
+ const int nb = k / qk ;
10763
+ const size_t bs = (2 * sizeof (float ) + sizeof (uint8_t )* qk /2 );
10764
+ const size_t row_size = nb * bs ;
10765
+
10766
+ assert (k % qk == 0 );
10767
+
10768
+ const size_t pp_size = qk / 2 ;
10769
+ uint8_t * pp = (uint8_t * ) alloca (pp_size );
10770
+
10771
+ char * pdst = (char * ) dst ;
10772
+
10773
+ for (int j = 0 ; j < n ; j += k ) {
10774
+ uint8_t * pd = (uint8_t * ) (pdst + (j /k )* row_size + 0 * bs );
10775
+ uint8_t * pm = (uint8_t * ) (pdst + (j /k )* row_size + 0 * bs + sizeof (float ));
10776
+ uint8_t * pb = (uint8_t * ) (pdst + (j /k )* row_size + 0 * bs + 2 * sizeof (float ));
10777
+
10778
+ //printf("n = %d, k = %d, nb = %d, row_size = %d, j = %d, pm = %p, pd = %p, pb = %p\n", n, k, nb, row_size, j, pm, pd, pb);
10779
+
10780
+ for (int i = 0 ; i < nb ; i ++ ) {
10781
+ float min = FLT_MAX ;
10782
+ float max = - FLT_MAX ;
10783
+
10784
+ {
10785
+ for (int l = 0 ; l < qk ; l ++ ) {
10786
+ const float v = src [j + i * qk + l ];
10787
+ if (v < min ) min = v ;
10788
+ if (v > max ) max = v ;
10789
+ }
10790
+
10791
+ const float d = (max - min ) / ((1 << 4 ) - 1 );
10792
+ const float id = d ? 1.0f /d : 0.0f ;
10793
+
10794
+ * (float * ) pd = d ;
10795
+ * (float * ) pm = min ;
10796
+ pd += bs ;
10797
+ pm += bs ;
10798
+
10799
+ for (int l = 0 ; l < qk ; l += 2 ) {
10800
+ const float v0 = (src [j + i * qk + l + 0 ] - min )* id ;
10801
+ const float v1 = (src [j + i * qk + l + 1 ] - min )* id ;
10802
+
10803
+ const uint8_t vi0 = round (v0 );
10804
+ const uint8_t vi1 = round (v1 );
10805
+
10806
+ assert (vi0 >= 0 && vi0 < 16 );
10807
+ assert (vi1 >= 0 && vi1 < 16 );
10808
+
10809
+ hist [vi0 ]++ ;
10810
+ hist [vi1 ]++ ;
10811
+
10812
+ pp [l /2 ] = vi0 | (vi1 << 4 );
10813
+ }
10814
+
10815
+ memcpy (pb , pp , pp_size );
10816
+ pb += bs ;
10817
+ }
10818
+ }
10819
+ }
10820
+
10821
+ return (n /k )* row_size ;
10822
+ }
10823
+
10824
+ ////////////////////////////////////////////////////////////////////////////////
10825
+
10705
10826
int ggml_cpu_has_avx (void ) {
10706
10827
#if defined(__AVX__ )
10707
10828
return 1 ;
0 commit comments