Skip to content

Commit a916a07

Browse files
authored
PQ ranging bugfix and refactoring (#508)
* PQ ranging bugfix and refactoring
* cover original bug with unit test
* use DRY fields
* remove commented code and misplaced comment
* refinements around ranging fix
* variable update
* add example data
* make PQLayout package private
* move PQLayout to PQVectors
1 parent 6639992 commit a916a07

File tree

2 files changed

+222
-89
lines changed

2 files changed

+222
-89
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java

Lines changed: 86 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package io.github.jbellis.jvector.quantization;
1818

19-
import io.github.jbellis.jvector.annotations.VisibleForTesting;
2019
import io.github.jbellis.jvector.disk.RandomAccessReader;
2120
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
2221
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
@@ -37,7 +36,6 @@
3736

3837
public abstract class PQVectors implements CompressedVectors {
3938
private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
40-
static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom
4139

4240
final ProductQuantization pq;
4341
protected ByteSequence<?>[] compressedDataChunks;
@@ -55,51 +53,19 @@ public static ImmutablePQVectors load(RandomAccessReader in) throws IOException
5553
int vectorCount = in.readInt();
5654
int compressedDimension = in.readInt();
5755

58-
int[] params = calculateChunkParameters(vectorCount, compressedDimension);
59-
int vectorsPerChunk = params[0];
60-
int totalChunks = params[1];
61-
int fullSizeChunks = params[2];
62-
int remainingVectors = params[3];
56+
PQLayout layout = new PQLayout(vectorCount,compressedDimension);
57+
ByteSequence<?>[] chunks = new ByteSequence<?>[layout.totalChunks];
6358

64-
ByteSequence<?>[] chunks = new ByteSequence<?>[totalChunks];
65-
int chunkBytes = vectorsPerChunk * compressedDimension;
66-
67-
for (int i = 0; i < fullSizeChunks; i++) {
68-
chunks[i] = vectorTypeSupport.readByteSequence(in, chunkBytes);
59+
for (int i = 0; i < layout.fullSizeChunks; i++) {
60+
chunks[i] = vectorTypeSupport.readByteSequence(in, layout.fullChunkBytes);
6961
}
7062

7163
// Last chunk might be smaller
72-
if (totalChunks > fullSizeChunks) {
73-
chunks[fullSizeChunks] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension);
64+
if (layout.totalChunks > layout.fullSizeChunks) {
65+
chunks[layout.fullSizeChunks] = vectorTypeSupport.readByteSequence(in, layout.lastChunkBytes);
7466
}
7567

76-
return new ImmutablePQVectors(pq, chunks, vectorCount, vectorsPerChunk);
77-
}
78-
79-
/**
80-
* Calculate chunking parameters for the given vector count and compressed dimension
81-
* @return array of [vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors]
82-
*/
83-
@VisibleForTesting
84-
static int[] calculateChunkParameters(int vectorCount, int compressedDimension) {
85-
if (vectorCount < 0) {
86-
throw new IllegalArgumentException("Invalid vector count " + vectorCount);
87-
}
88-
if (compressedDimension < 0) {
89-
throw new IllegalArgumentException("Invalid compressed dimension " + compressedDimension);
90-
}
91-
92-
long totalSize = (long) vectorCount * compressedDimension;
93-
int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension;
94-
if (vectorsPerChunk == 0) {
95-
throw new IllegalArgumentException("Compressed dimension " + compressedDimension + " too large for chunking");
96-
}
97-
98-
int fullSizeChunks = vectorCount / vectorsPerChunk;
99-
int totalChunks = vectorCount % vectorsPerChunk == 0 ? fullSizeChunks : fullSizeChunks + 1;
100-
101-
int remainingVectors = vectorCount % vectorsPerChunk;
102-
return new int[] {vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors};
68+
return new ImmutablePQVectors(pq, chunks, vectorCount, layout.fullChunkVectors);
10369
}
10470

10571
public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
@@ -118,20 +84,15 @@ public static PQVectors load(RandomAccessReader in, long offset) throws IOExcept
11884
* @return the PQVectors instance
11985
*/
12086
public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vectorCount, RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
121-
// Calculate if we need to split into multiple chunks
12287
int compressedDimension = pq.compressedVectorSize();
123-
long totalSize = (long) vectorCount * compressedDimension;
124-
int vectorsPerChunk = totalSize <= PQVectors.MAX_CHUNK_SIZE ? vectorCount : PQVectors.MAX_CHUNK_SIZE / compressedDimension;
125-
126-
int numChunks = vectorCount / vectorsPerChunk;
127-
final ByteSequence<?>[] chunks = new ByteSequence<?>[numChunks];
128-
int chunkSize = vectorsPerChunk * compressedDimension;
129-
for (int i = 0; i < numChunks - 1; i++)
130-
chunks[i] = vectorTypeSupport.createByteSequence(chunkSize);
131-
132-
// Last chunk might be smaller
133-
int remainingVectors = vectorCount - (vectorsPerChunk * (numChunks - 1));
134-
chunks[numChunks - 1] = vectorTypeSupport.createByteSequence(remainingVectors * compressedDimension);
88+
PQLayout layout = new PQLayout(vectorCount,compressedDimension);
89+
final ByteSequence<?>[] chunks = new ByteSequence<?>[layout.totalChunks];
90+
for (int i = 0; i < layout.fullSizeChunks; i++) {
91+
chunks[i] = vectorTypeSupport.createByteSequence(layout.fullChunkBytes);
92+
}
93+
if (layout.lastChunkVectors > 0) {
94+
chunks[layout.fullSizeChunks] = vectorTypeSupport.createByteSequence(layout.lastChunkBytes);
95+
}
13596

13697
// Encode the vectors in parallel into the compressed data chunks
13798
// The changes are concurrent, but because they are coordinated and do not overlap, we can use parallel streams
@@ -142,7 +103,7 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect
142103
.forEach(ordinal -> {
143104
// Retrieve the slice and mutate it.
144105
var localRavv = ravvCopy.get();
145-
var slice = PQVectors.get(chunks, ordinal, vectorsPerChunk, pq.getSubspaceCount());
106+
var slice = PQVectors.get(chunks, ordinal, layout.fullChunkVectors, pq.getSubspaceCount());
146107
var vector = localRavv.getVector(ordinal);
147108
if (vector != null)
148109
pq.encodeTo(vector, slice);
@@ -151,7 +112,7 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect
151112
}))
152113
.join();
153114

154-
return new ImmutablePQVectors(pq, chunks, vectorCount, vectorsPerChunk);
115+
return new ImmutablePQVectors(pq, chunks, vectorCount, layout.fullChunkVectors);
155116
}
156117

157118
@Override
@@ -443,4 +404,73 @@ public String toString() {
443404
", count=" + count() +
444405
'}';
445406
}
407+
408+
/**
409+
* Chunk Dimensions and Layout
410+
* This is emulative of modern Java records, but keeps to J11 standards.
411+
* This class consolidates the layout calculations for PQ data into one place
412+
*/
413+
static class PQLayout {
414+
415+
/**
416+
* total number of vectors
417+
**/
418+
public final int vectorCount;
419+
/**
420+
* total number of chunks, including any partial
421+
**/
422+
public final int totalChunks;
423+
/**
424+
* total number of fully-filled chunks
425+
**/
426+
public final int fullSizeChunks;
427+
/**
428+
* number of vectors per fullSize chunk
429+
**/
430+
public final int fullChunkVectors;
431+
/**
432+
* number of vectors in last partially filled chunk, if any
433+
**/
434+
public final int lastChunkVectors;
435+
/**
436+
* compressed dimension of vectors
437+
**/
438+
public final int compressedDimension;
439+
/**
440+
* number of bytes in each fully-filled chunk
441+
**/
442+
public final int fullChunkBytes;
443+
/**
444+
* number of bytes in the last partially-filled chunk, if any
445+
**/
446+
public final int lastChunkBytes;
447+
448+
public PQLayout(int vectorCount, int compressedDimension) {
449+
if (vectorCount <= 0) {
450+
throw new IllegalArgumentException("Invalid vector count " + vectorCount);
451+
}
452+
this.vectorCount = vectorCount;
453+
454+
if (compressedDimension <= 0) {
455+
throw new IllegalArgumentException("Invalid compressed dimension " + compressedDimension);
456+
}
457+
this.compressedDimension = compressedDimension;
458+
459+
// Get the aligned number of bytes needed to hold a given dimension
460+
// purely for overflow prevention
461+
int layoutBytesPerVector = compressedDimension == 1 ? 1 : Integer.highestOneBit(compressedDimension - 1) << 1;
462+
// truncation welcome here, biasing for smaller chunks
463+
int addressableVectorsPerChunk = Integer.MAX_VALUE / layoutBytesPerVector;
464+
465+
fullChunkVectors = Math.min(vectorCount, addressableVectorsPerChunk);
466+
lastChunkVectors = vectorCount % fullChunkVectors;
467+
468+
fullChunkBytes = fullChunkVectors * compressedDimension;
469+
lastChunkBytes = lastChunkVectors * compressedDimension;
470+
471+
fullSizeChunks = vectorCount / fullChunkVectors;
472+
totalChunks = fullSizeChunks + (lastChunkVectors == 0 ? 0 : 1);
473+
}
474+
475+
}
446476
}

0 commit comments

Comments
 (0)