Skip to content

Commit 1c29821

Browse files
authored
Create partial sums for PQ codebook for use during diversity checks (#511)
* Create partial sums for PQ codebook for use during diversity checks of graph building Signed-off-by: Jake Luciani <[email protected]>
1 parent a916a07 commit 1c29821

File tree

9 files changed

+395
-3
lines changed

9 files changed

+395
-3
lines changed

benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithRandomSetBenchmark.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
@BenchmarkMode(Mode.AverageTime)
4242
@OutputTimeUnit(TimeUnit.MILLISECONDS)
4343
@State(Scope.Thread)
44-
@Fork(1)
44+
@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=false"})
4545
@Warmup(iterations = 2)
4646
@Measurement(iterations = 3)
4747
@Threads(1)
@@ -59,7 +59,7 @@ public class IndexConstructionWithRandomSetBenchmark {
5959
@Param({"0", "16"})
6060
private int numberOfPQSubspaces;
6161

62-
@Setup(Level.Invocation)
62+
@Setup(Level.Trial)
6363
public void setup() throws IOException {
6464

6565
final var baseVectors = new ArrayList<VectorFloat<?>>(numBaseVectors);

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,18 @@
1616

1717
package io.github.jbellis.jvector.quantization;
1818

19+
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
20+
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
21+
import io.github.jbellis.jvector.vector.VectorUtil;
1922
import io.github.jbellis.jvector.vector.types.ByteSequence;
23+
import io.github.jbellis.jvector.vector.types.VectorFloat;
24+
25+
import java.util.HashMap;
26+
import java.util.Map;
2027

2128
public class ImmutablePQVectors extends PQVectors {
2229
private final int vectorCount;
30+
private final Map<VectorSimilarityFunction, VectorFloat<?>> codebookPartialSumsMap;
2331

2432
/**
2533
* Construct an immutable PQVectors instance with the given ProductQuantization and compressed data chunks.
@@ -33,6 +41,7 @@ public ImmutablePQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDa
3341
this.compressedDataChunks = compressedDataChunks;
3442
this.vectorCount = vectorCount;
3543
this.vectorsPerChunk = vectorsPerChunk;
44+
this.codebookPartialSumsMap = new HashMap<>();
3645
}
3746

3847
@Override
@@ -44,4 +53,54 @@ protected int validChunkCount() {
4453
public int count() {
4554
return vectorCount;
4655
}
56+
57+
private synchronized VectorFloat<?> getOrCreateCodebookPartialSums(VectorSimilarityFunction vsf) {
58+
return codebookPartialSumsMap.computeIfAbsent(vsf, pq::createCodebookPartialSums);
59+
}
60+
61+
@Override
62+
public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, VectorSimilarityFunction similarityFunction) {
63+
final int subspaceCount = pq.getSubspaceCount();
64+
var node1Chunk = getChunk(node1);
65+
var node1Offset = getOffsetInChunk(node1);
66+
int clusterCount = pq.getClusterCount();
67+
68+
VectorFloat<?> codebookPartialSums = getOrCreateCodebookPartialSums(similarityFunction);
69+
70+
switch (similarityFunction) {
71+
case DOT_PRODUCT:
72+
return (node2) -> {
73+
var node2Chunk = getChunk(node2);
74+
var node2Offset = getOffsetInChunk(node2);
75+
// compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
76+
float sum = VectorUtil.assembleAndSumPQ(codebookPartialSums, subspaceCount, node1Chunk, node1Offset, node2Chunk, node2Offset, clusterCount);
77+
// scale to [0, 1]
78+
return (1 + sum) / 2;
79+
};
80+
case COSINE:
81+
float norm1 = VectorUtil.assembleAndSumPQ(codebookPartialSums, subspaceCount, node1Chunk, node1Offset, node1Chunk, node1Offset, clusterCount);
82+
return (node2) -> {
83+
var node2Chunk = getChunk(node2);
84+
var node2Offset = getOffsetInChunk(node2);
85+
// compute the dot product of the query and the codebook centroids corresponding to the encoded points
86+
float sum = VectorUtil.assembleAndSumPQ(codebookPartialSums, subspaceCount, node1Chunk, node1Offset, node2Chunk, node2Offset, clusterCount);
87+
float norm2 = VectorUtil.assembleAndSumPQ(codebookPartialSums, subspaceCount, node2Chunk, node2Offset, node2Chunk, node2Offset, clusterCount);
88+
float cosine = sum / (float) Math.sqrt(norm1 * norm2);
89+
// scale to [0, 1]
90+
return (1 + cosine) / 2;
91+
};
92+
case EUCLIDEAN:
93+
return (node2) -> {
94+
var node2Chunk = getChunk(node2);
95+
var node2Offset = getOffsetInChunk(node2);
96+
// compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
97+
float sum = VectorUtil.assembleAndSumPQ(codebookPartialSums, subspaceCount, node1Chunk, node1Offset, node2Chunk, node2Offset, clusterCount);
98+
99+
// scale to [0, 1]
100+
return 1 / (1 + sum);
101+
};
102+
default:
103+
throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
104+
}
105+
}
47106
}

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
2323
import io.github.jbellis.jvector.util.Accountable;
2424
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
25+
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
2526
import io.github.jbellis.jvector.vector.VectorUtil;
2627
import io.github.jbellis.jvector.vector.VectorizationProvider;
2728
import io.github.jbellis.jvector.vector.types.ByteSequence;
@@ -36,6 +37,7 @@
3637
import java.util.concurrent.ForkJoinPool;
3738
import java.util.concurrent.ThreadLocalRandom;
3839
import java.util.concurrent.atomic.AtomicReference;
40+
import java.util.function.Supplier;
3941
import java.util.logging.Logger;
4042
import java.util.stream.Collectors;
4143
import java.util.stream.IntStream;
@@ -586,6 +588,35 @@ public void write(DataOutput out, int version) throws IOException
586588
}
587589
}
588590

591+
/**
592+
* Creates a vector to hold partial sums for a single codebook.
593+
* The partial sums are the dot products of each subvector centroid in the codebook with the other subvector centroids.
594+
* Since the dot product is commutative, we only need to store the upper triangle of the matrix.
595+
* There are M codebooks, and each codebook has k centroids, so the total number of partial sums is M * k * (k+1) / 2.
596+
*
597+
* @return a vector to hold partial sums for a single codebook
598+
*/
599+
public VectorFloat<?> createCodebookPartialSums(VectorSimilarityFunction vectorSimilarityFunction) {
600+
VectorFloat<?> partialSums = vectorTypeSupport.createFloatVector(getSubspaceCount() * getClusterCount() * (getClusterCount() + 1) / 2);
601+
int index = 0;
602+
for (int m = 0; m < M; m++) {
603+
int size = subvectorSizesAndOffsets[m][0];
604+
var codebook = codebooks[m];
605+
for (int i = 0; i < clusterCount; i++) {
606+
for (int j = i; j < clusterCount; j++) {
607+
608+
float sum = vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN ?
609+
VectorUtil.squareL2Distance(codebook, i * size, codebook, j * size, size) :
610+
VectorUtil.dotProduct(codebook, i * size, codebook, j * size, size);
611+
612+
partialSums.set(index++, sum);
613+
}
614+
}
615+
}
616+
617+
return partialSums;
618+
}
619+
589620
@Override
590621
public int compressorSize() {
591622
int size = 0;
@@ -734,4 +765,8 @@ private static void checkClusterCount(int clusterCount) {
734765
LOG.warning("Using less than 256 PQ clusters will not reduce the memory footprint.");
735766
}
736767
}
768+
769+
public int getOriginalDimension() {
770+
return originalDimension;
771+
}
737772
}

jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,36 @@ public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> b
308308
return sum;
309309
}
310310

311+
@Override
312+
public float assembleAndSumPQ(
313+
VectorFloat<?> codebookPartialSums,
314+
int subspaceCount, // = M
315+
ByteSequence<?> vector1Ordinals,
316+
int vector1OrdinalOffset,
317+
ByteSequence<?> vector2Ordinals,
318+
int vector2OrdinalOffset,
319+
int clusterCount
320+
) {
321+
final int k = clusterCount;
322+
final int blockSize = k * (k + 1) / 2;
323+
float res = 0f;
324+
325+
for (int i = 0; i < subspaceCount; i++) {
326+
int c1 = Byte.toUnsignedInt(vector1Ordinals.get(i + vector1OrdinalOffset));
327+
int c2 = Byte.toUnsignedInt(vector2Ordinals.get(i + vector2OrdinalOffset));
328+
int r = Math.min(c1, c2);
329+
int c = Math.max(c1, c2);
330+
331+
int offsetRow = r * k - (r * (r - 1) / 2);
332+
int idxInBlock = offsetRow + (c - r);
333+
int base = i * blockSize;
334+
335+
res += codebookPartialSums.get(base + idxInBlock);
336+
}
337+
338+
return res;
339+
}
340+
311341
@Override
312342
public int hammingDistance(long[] v1, long[] v2) {
313343
int hd = 0;

jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ public static float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequen
170170
return impl.assembleAndSum(data, dataBase, dataOffsets, dataOffsetsOffset, dataOffsetsLength);
171171
}
172172

173+
public static float assembleAndSumPQ(VectorFloat<?> data, int subspaceCount, ByteSequence<?> dataOffsets1, int dataOffsetsOffset1, ByteSequence<?> dataOffsets2, int dataOffsetsOffset2, int clusterCount) {
174+
return impl.assembleAndSumPQ(data, subspaceCount, dataOffsets1, dataOffsetsOffset1, dataOffsets2, dataOffsetsOffset2, clusterCount);
175+
}
176+
173177
public static void bulkShuffleQuantizedSimilarity(ByteSequence<?> shuffles, int codebookCount, ByteSequence<?> quantizedPartials, float delta, float minDistance, VectorFloat<?> results, VectorSimilarityFunction vsf) {
174178
impl.bulkShuffleQuantizedSimilarity(shuffles, codebookCount, quantizedPartials, delta, minDistance, vsf, results);
175179
}
@@ -250,5 +254,4 @@ public static float nvqLoss(VectorFloat<?> vector, float growthRate, float midpo
250254
public static float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits) {
251255
return impl.nvqUniformLoss(vector, minValue, maxValue, nBits);
252256
}
253-
254257
}

jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
package io.github.jbellis.jvector.vector;
2626

27+
import io.github.jbellis.jvector.quantization.ProductQuantization;
2728
import io.github.jbellis.jvector.vector.types.ByteSequence;
2829
import io.github.jbellis.jvector.vector.types.VectorFloat;
2930

@@ -113,6 +114,22 @@ public interface VectorUtilSupport {
113114
*/
114115
float assembleAndSum(VectorFloat<?> data, int baseIndex, ByteSequence<?> baseOffsets, int baseOffsetsOffset, int baseOffsetsLength);
115116

117+
/**
118+
* Calculates the distance between 2 vectors, which were quantized using Product Quantization, using a precomputed table of partial results
119+
*
120+
* See {@link ProductQuantization#createCodebookPartialSums(VectorSimilarityFunction)}
121+
*
122+
* @param codebookPartialSums the vector of all PQ
123+
* @param subspaceCount the number of PQ subspaces
124+
* @param vector1Ordinals Specifies which centroid vector is used for each of node1's subvectors
125+
* @param vector1OrdinalOffset the offset into the vector1Ordinals ByteSequence for node1 (in case vector1Ordinals is a chunk of many nodes)
126+
* @param node2Ordinals Specifies which centroid vector is used for each of node2's subvectors
127+
* @param node2OrdinalOffset the offset into the vector1Ordinals ByteSequence for node2 (in case vector1Ordinals is a chunk of many nodes)
128+
* @param clusterCount the number of PQ clusters per subvector in the codebook
129+
* @return the sum of the vectors
130+
*/
131+
float assembleAndSumPQ(VectorFloat<?> codebookPartialSums, int subspaceCount, ByteSequence<?> vector1Ordinals, int vector1OrdinalOffset, ByteSequence<?> node2Ordinals, int node2OrdinalOffset, int clusterCount);
132+
116133
int hammingDistance(long[] v1, long[] v2);
117134

118135
// default implementation used here because Panama SIMD can't express necessary SIMD operations and degrades to scalar

jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,20 @@ public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> b
8080
}
8181

8282

83+
@Override
84+
public float assembleAndSumPQ(
85+
VectorFloat<?> codebookPartialSums,
86+
int subspaceCount, // = M
87+
ByteSequence<?> vector1Ordinals,
88+
int vector1OrdinalOffset,
89+
ByteSequence<?> vector2Ordinals,
90+
int vector2OrdinalOffset,
91+
int clusterCount // = k
92+
) {
93+
//Use the non-panama solution for now
94+
return assembleAndSumPQ_128(codebookPartialSums, subspaceCount, vector1Ordinals, vector1OrdinalOffset, vector2Ordinals, vector2OrdinalOffset, clusterCount);
95+
}
96+
8397
@Override
8498
public void calculatePartialSums(VectorFloat<?> codebook, int codebookBase, int size, int clusterCount, VectorFloat<?> query, int queryOffset, VectorSimilarityFunction vsf, VectorFloat<?> partialSums) {
8599
switch (vsf) {

jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestProductQuantization.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import java.nio.file.Files;
3535
import java.util.Arrays;
3636
import java.util.List;
37+
import java.util.concurrent.ForkJoinPool;
3738
import java.util.stream.Collectors;
3839
import java.util.stream.IntStream;
3940

@@ -406,5 +407,31 @@ public void testPQLayoutEdgeCases() {
406407
System.out.println("Test completed successfully");
407408
}
408409

410+
@Test
411+
public void testPQCodebookSums() {
412+
// Generate a PQ for random 2D vectors
413+
var vectors = createRandomVectors(10000, 384);
414+
var pq = ProductQuantization.compute(new ListRandomAccessVectorValues(vectors, 384), 48, 256, false);
415+
416+
MutablePQVectors pqm = new MutablePQVectors(pq);
417+
418+
// build the index vector-at-a-time (on disk)
419+
for (int ordinal = 0; ordinal < vectors.size(); ordinal++)
420+
{
421+
VectorFloat<?> v = vectors.get(ordinal);
422+
// compress the new vector and add it to the PQVectors
423+
pqm.encodeAndSet(ordinal, v);
424+
}
409425

426+
for (VectorSimilarityFunction vsf : VectorSimilarityFunction.values()) {
427+
var sf = pqm.diversityFunctionFor(10, vsf);
428+
429+
ImmutablePQVectors pqi = ImmutablePQVectors.encodeAndBuild(pq, vectors.size(), new ListRandomAccessVectorValues(vectors, 384), ForkJoinPool.commonPool());
430+
var sf2 = pqi.diversityFunctionFor(10, vsf);
431+
432+
for (int i = 0; i < vectors.size(); i++) {
433+
assertEquals(vsf.name(), sf.similarityTo(i), sf2.similarityTo(i), 1e-6);
434+
}
435+
}
436+
}
410437
}

0 commit comments

Comments
 (0)