16
16
17
17
package io .github .jbellis .jvector .quantization ;
18
18
19
- import io .github .jbellis .jvector .annotations .VisibleForTesting ;
20
19
import io .github .jbellis .jvector .disk .RandomAccessReader ;
21
20
import io .github .jbellis .jvector .graph .RandomAccessVectorValues ;
22
21
import io .github .jbellis .jvector .graph .similarity .ScoreFunction ;
37
36
38
37
public abstract class PQVectors implements CompressedVectors {
39
38
private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider .getInstance ().getVectorTypeSupport ();
40
- static final int MAX_CHUNK_SIZE = Integer .MAX_VALUE - 16 ; // standard Java array size limit with some headroom
41
39
42
40
final ProductQuantization pq ;
43
41
protected ByteSequence <?>[] compressedDataChunks ;
@@ -55,51 +53,19 @@ public static ImmutablePQVectors load(RandomAccessReader in) throws IOException
55
53
int vectorCount = in .readInt ();
56
54
int compressedDimension = in .readInt ();
57
55
58
- int [] params = calculateChunkParameters (vectorCount , compressedDimension );
59
- int vectorsPerChunk = params [0 ];
60
- int totalChunks = params [1 ];
61
- int fullSizeChunks = params [2 ];
62
- int remainingVectors = params [3 ];
56
+ PQLayout layout = new PQLayout (vectorCount ,compressedDimension );
57
+ ByteSequence <?>[] chunks = new ByteSequence <?>[layout .totalChunks ];
63
58
64
- ByteSequence <?>[] chunks = new ByteSequence <?>[totalChunks ];
65
- int chunkBytes = vectorsPerChunk * compressedDimension ;
66
-
67
- for (int i = 0 ; i < fullSizeChunks ; i ++) {
68
- chunks [i ] = vectorTypeSupport .readByteSequence (in , chunkBytes );
59
+ for (int i = 0 ; i < layout .fullSizeChunks ; i ++) {
60
+ chunks [i ] = vectorTypeSupport .readByteSequence (in , layout .fullChunkBytes );
69
61
}
70
62
71
63
// Last chunk might be smaller
72
- if (totalChunks > fullSizeChunks ) {
73
- chunks [fullSizeChunks ] = vectorTypeSupport .readByteSequence (in , remainingVectors * compressedDimension );
64
+ if (layout . totalChunks > layout . fullSizeChunks ) {
65
+ chunks [layout . fullSizeChunks ] = vectorTypeSupport .readByteSequence (in , layout . lastChunkBytes );
74
66
}
75
67
76
- return new ImmutablePQVectors (pq , chunks , vectorCount , vectorsPerChunk );
77
- }
78
-
79
- /**
80
- * Calculate chunking parameters for the given vector count and compressed dimension
81
- * @return array of [vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors]
82
- */
83
- @ VisibleForTesting
84
- static int [] calculateChunkParameters (int vectorCount , int compressedDimension ) {
85
- if (vectorCount < 0 ) {
86
- throw new IllegalArgumentException ("Invalid vector count " + vectorCount );
87
- }
88
- if (compressedDimension < 0 ) {
89
- throw new IllegalArgumentException ("Invalid compressed dimension " + compressedDimension );
90
- }
91
-
92
- long totalSize = (long ) vectorCount * compressedDimension ;
93
- int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension ;
94
- if (vectorsPerChunk == 0 ) {
95
- throw new IllegalArgumentException ("Compressed dimension " + compressedDimension + " too large for chunking" );
96
- }
97
-
98
- int fullSizeChunks = vectorCount / vectorsPerChunk ;
99
- int totalChunks = vectorCount % vectorsPerChunk == 0 ? fullSizeChunks : fullSizeChunks + 1 ;
100
-
101
- int remainingVectors = vectorCount % vectorsPerChunk ;
102
- return new int [] {vectorsPerChunk , totalChunks , fullSizeChunks , remainingVectors };
68
+ return new ImmutablePQVectors (pq , chunks , vectorCount , layout .fullChunkVectors );
103
69
}
104
70
105
71
public static PQVectors load (RandomAccessReader in , long offset ) throws IOException {
@@ -118,20 +84,15 @@ public static PQVectors load(RandomAccessReader in, long offset) throws IOExcept
118
84
* @return the PQVectors instance
119
85
*/
120
86
public static ImmutablePQVectors encodeAndBuild (ProductQuantization pq , int vectorCount , RandomAccessVectorValues ravv , ForkJoinPool simdExecutor ) {
121
- // Calculate if we need to split into multiple chunks
122
87
int compressedDimension = pq .compressedVectorSize ();
123
- long totalSize = (long ) vectorCount * compressedDimension ;
124
- int vectorsPerChunk = totalSize <= PQVectors .MAX_CHUNK_SIZE ? vectorCount : PQVectors .MAX_CHUNK_SIZE / compressedDimension ;
125
-
126
- int numChunks = vectorCount / vectorsPerChunk ;
127
- final ByteSequence <?>[] chunks = new ByteSequence <?>[numChunks ];
128
- int chunkSize = vectorsPerChunk * compressedDimension ;
129
- for (int i = 0 ; i < numChunks - 1 ; i ++)
130
- chunks [i ] = vectorTypeSupport .createByteSequence (chunkSize );
131
-
132
- // Last chunk might be smaller
133
- int remainingVectors = vectorCount - (vectorsPerChunk * (numChunks - 1 ));
134
- chunks [numChunks - 1 ] = vectorTypeSupport .createByteSequence (remainingVectors * compressedDimension );
88
+ PQLayout layout = new PQLayout (vectorCount ,compressedDimension );
89
+ final ByteSequence <?>[] chunks = new ByteSequence <?>[layout .totalChunks ];
90
+ for (int i = 0 ; i < layout .fullSizeChunks ; i ++) {
91
+ chunks [i ] = vectorTypeSupport .createByteSequence (layout .fullChunkBytes );
92
+ }
93
+ if (layout .lastChunkVectors > 0 ) {
94
+ chunks [layout .fullSizeChunks ] = vectorTypeSupport .createByteSequence (layout .lastChunkBytes );
95
+ }
135
96
136
97
// Encode the vectors in parallel into the compressed data chunks
137
98
// The changes are concurrent, but because they are coordinated and do not overlap, we can use parallel streams
@@ -142,7 +103,7 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect
142
103
.forEach (ordinal -> {
143
104
// Retrieve the slice and mutate it.
144
105
var localRavv = ravvCopy .get ();
145
- var slice = PQVectors .get (chunks , ordinal , vectorsPerChunk , pq .getSubspaceCount ());
106
+ var slice = PQVectors .get (chunks , ordinal , layout . fullChunkVectors , pq .getSubspaceCount ());
146
107
var vector = localRavv .getVector (ordinal );
147
108
if (vector != null )
148
109
pq .encodeTo (vector , slice );
@@ -151,7 +112,7 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect
151
112
}))
152
113
.join ();
153
114
154
- return new ImmutablePQVectors (pq , chunks , vectorCount , vectorsPerChunk );
115
+ return new ImmutablePQVectors (pq , chunks , vectorCount , layout . fullChunkVectors );
155
116
}
156
117
157
118
@ Override
@@ -443,4 +404,73 @@ public String toString() {
443
404
", count=" + count () +
444
405
'}' ;
445
406
}
407
+
408
+ /**
409
+ * Chunk Dimensions and Layout
410
+ * This is emulative of modern Java records, but keeps to J11 standards.
411
+ * This class consolidates the layout calculations for PQ data into one place
412
+ */
413
+ static class PQLayout {
414
+
415
+ /**
416
+ * total number of vectors
417
+ **/
418
+ public final int vectorCount ;
419
+ /**
420
+ * total number of chunks, including any partial
421
+ **/
422
+ public final int totalChunks ;
423
+ /**
424
+ * total number of fully-filled chunks
425
+ **/
426
+ public final int fullSizeChunks ;
427
+ /**
428
+ * number of vectors per fullSize chunk
429
+ **/
430
+ public final int fullChunkVectors ;
431
+ /**
432
+ * number of vectors in last partially filled chunk, if any
433
+ **/
434
+ public final int lastChunkVectors ;
435
+ /**
436
+ * compressed dimension of vectors
437
+ **/
438
+ public final int compressedDimension ;
439
+ /**
440
+ * number of bytes in each fully-filled chunk
441
+ **/
442
+ public final int fullChunkBytes ;
443
+ /**
444
+ * number of bytes in the last partially-filled chunk, if any
445
+ **/
446
+ public final int lastChunkBytes ;
447
+
448
+ public PQLayout (int vectorCount , int compressedDimension ) {
449
+ if (vectorCount <= 0 ) {
450
+ throw new IllegalArgumentException ("Invalid vector count " + vectorCount );
451
+ }
452
+ this .vectorCount = vectorCount ;
453
+
454
+ if (compressedDimension <= 0 ) {
455
+ throw new IllegalArgumentException ("Invalid compressed dimension " + compressedDimension );
456
+ }
457
+ this .compressedDimension = compressedDimension ;
458
+
459
+ // Get the aligned number of bytes needed to hold a given dimension
460
+ // purely for overflow prevention
461
+ int layoutBytesPerVector = compressedDimension == 1 ? 1 : Integer .highestOneBit (compressedDimension - 1 ) << 1 ;
462
+ // truncation welcome here, biasing for smaller chunks
463
+ int addressableVectorsPerChunk = Integer .MAX_VALUE / layoutBytesPerVector ;
464
+
465
+ fullChunkVectors = Math .min (vectorCount , addressableVectorsPerChunk );
466
+ lastChunkVectors = vectorCount % fullChunkVectors ;
467
+
468
+ fullChunkBytes = fullChunkVectors * compressedDimension ;
469
+ lastChunkBytes = lastChunkVectors * compressedDimension ;
470
+
471
+ fullSizeChunks = vectorCount / fullChunkVectors ;
472
+ totalChunks = fullSizeChunks + (lastChunkVectors == 0 ? 0 : 1 );
473
+ }
474
+
475
+ }
446
476
}
0 commit comments