From dd1c332d04a4d3a0debf43bc902d284331391f5c Mon Sep 17 00:00:00 2001
From: "Kokoori, Shylaja"
Date: Thu, 28 Aug 2025 16:07:01 -0700
Subject: [PATCH] Optimizes reducelanes in diversityCalculation of PQVectors,
 for Euclidean function

---
 .../jvector/quantization/PQVectors.java       |   8 +-
 .../vector/DefaultVectorUtilSupport.java      |  12 ++
 .../jbellis/jvector/vector/VectorUtil.java    |   4 +
 .../jvector/vector/VectorUtilSupport.java     |  15 +-
 .../vector/PanamaVectorUtilSupport.java       | 183 ++++++++++++++++++
 5 files changed, 214 insertions(+), 8 deletions(-)

diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java
index f66a2c6e4..8253f1e6f 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java
@@ -298,13 +298,7 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve
                     var node2Chunk = getChunk(node2);
                     var node2Offset = getOffsetInChunk(node2);
                     // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
-                    float sum = 0;
-                    for (int m = 0; m < subspaceCount; m++) {
-                        int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
-                        int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
-                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
-                        sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength);
-                    }
+                    float sum = VectorUtil.pqDiversityEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
                     // scale to [0, 1]
                     return 1 / (1 + sum);
                 };
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java
index 867e1c85d..91d3445e0 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java
@@ -584,4 +584,16 @@ public float nvqUniformLoss(VectorFloat vector, float minValue, float maxValu
         return squaredSum;
     }
 
+    @Override
+    public float pqDiversityEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) {
+        float sum = 0;
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
+            int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            sum += squareDistance(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
+        }
+        return sum;
+    }
+
 }
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java
index e7e8b068f..c66f94307 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java
@@ -227,6 +227,10 @@ public static float pqDecodedCosineSimilarity(ByteSequence encoded, int encod
         return impl.pqDecodedCosineSimilarity(encoded, encodedOffset, encodedLength, clusterCount, partialSums, aMagnitude, bMagnitude);
     }
 
+    public static float pqDiversityEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) {
+        return impl.pqDiversityEuclidean(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
+    }
+
     public static float nvqDotProduct8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue) {
         return impl.nvqDotProduct8bit(vector, bytes, growthRate, midpoint, minValue, maxValue);
     }
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java
index cc1f74f1b..881d5e4c4 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java
@@ -337,5 +337,18 @@ default float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffs
      * @param nBits the number of bits per dimension
      */
     float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits);
-
+
+    /**
+     * Calculates the sum, over all subspaces, of the squared distance between the two codebook
+     * centroids selected by the encodings of a pair of nodes.
+     * @param codebooks the codebooks, indexed by subspace
+     * @param subvectorSizesAndOffsets for each subspace, the centroid length at index 0
+     * @param node1Chunk the byte sequence holding the first node's centroid indexes
+     * @param node1Offset the position of the first node's encoding within node1Chunk
+     * @param node2Chunk the byte sequence holding the second node's centroid indexes
+     * @param node2Offset the position of the second node's encoding within node2Chunk
+     * @param subspaceCount the number of subspaces to accumulate over
+     * @return the sum of the per-subspace squared distances
+     */
+    float pqDiversityEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount);
 }
diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
index eacb10866..d887f1a5c 100644
--- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
+++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
@@ -1576,5 +1576,188 @@ public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int
     public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) {
         return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude);
     }
+
+    /**
+     * Sums the squared centroid distances of two encoded nodes using 64-bit vectors.
+     * Squared differences are accumulated lane-wise and reduced once after the last subspace;
+     * elements left over after the last whole vector are accumulated as scalars.
+     */
+    float pqDiversityEuclidean_64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) {
+        float res = 0;
+        FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64);
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
+            int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength);
+            int offset1 = centroidIndex1 * centroidLength;
+            int offset2 = centroidIndex2 * centroidLength;
+            if (centroidLength == FloatVector.SPECIES_64.length()) {
+                // the centroid fills exactly one vector, so neither the loop nor the tail is needed
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], offset1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], offset2);
+                var diff = a.sub(b);
+                sum = diff.mul(diff).add(sum);
+            } else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], offset1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], offset2 + i);
+                    var diff = a.sub(b);
+                    sum = diff.mul(diff).add(sum);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    var diff = codebooks[m].get(offset1 + i) - codebooks[m].get(offset2 + i);
+                    res += diff * diff;
+                }
+            }
+        }
+        res += sum.reduceLanes(VectorOperators.ADD);
+        return res;
+    }
+
+    /**
+     * Sums the squared centroid distances of two encoded nodes using 128-bit vectors.
+     * Squared differences are accumulated lane-wise and reduced once after the last subspace;
+     * elements left over after the last whole vector are accumulated as scalars.
+     */
+    float pqDiversityEuclidean_128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) {
+        float res = 0;
+        FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128);
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
+            int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength);
+            int offset1 = centroidIndex1 * centroidLength;
+            int offset2 = centroidIndex2 * centroidLength;
+            if (centroidLength == FloatVector.SPECIES_128.length()) {
+                // the centroid fills exactly one vector, so neither the loop nor the tail is needed
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], offset1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], offset2);
+                var diff = a.sub(b);
+                sum = diff.mul(diff).add(sum);
+            } else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], offset1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], offset2 + i);
+                    var diff = a.sub(b);
+                    sum = diff.mul(diff).add(sum);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    var diff = codebooks[m].get(offset1 + i) - codebooks[m].get(offset2 + i);
+                    res += diff * diff;
+                }
+            }
+        }
+        res += sum.reduceLanes(VectorOperators.ADD);
+        return res;
+    }
+
+    /**
+     * Sums the squared centroid distances of two encoded nodes using 256-bit vectors.
+     * Squared differences are accumulated lane-wise and reduced once after the last subspace;
+     * elements left over after the last whole vector are accumulated as scalars.
+     */
+    float pqDiversityEuclidean_256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) {
+        float res = 0;
+        FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256);
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
+            int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength);
+            int offset1 = centroidIndex1 * centroidLength;
+            int offset2 = centroidIndex2 * centroidLength;
+            if (centroidLength == FloatVector.SPECIES_256.length()) {
+                // the centroid fills exactly one vector, so neither the loop nor the tail is needed
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], offset1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], offset2);
+                var diff = a.sub(b);
+                sum = diff.mul(diff).add(sum);
+            } else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], offset1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], offset2 + i);
+                    var diff = a.sub(b);
+                    sum = diff.mul(diff).add(sum);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    var diff = codebooks[m].get(offset1 + i) - codebooks[m].get(offset2 + i);
+                    res += diff * diff;
+                }
+            }
+        }
+        res += sum.reduceLanes(VectorOperators.ADD);
+        return res;
+    }
+
+    /**
+     * Sums the squared centroid distances of two encoded nodes using SPECIES_PREFERRED vectors
+     * (whatever width the platform prefers, despite the 512 in the method name).
+     * Squared differences are accumulated lane-wise and reduced once after the last subspace;
+     * elements left over after the last whole vector are accumulated as scalars.
+     */
+    float pqDiversityEuclidean_512(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) {
+        float res = 0;
+        FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
+            int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength);
+            int offset1 = centroidIndex1 * centroidLength;
+            int offset2 = centroidIndex2 * centroidLength;
+            if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) {
+                // the centroid fills exactly one vector, so neither the loop nor the tail is needed
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], offset1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], offset2);
+                var diff = a.sub(b);
+                sum = diff.mul(diff).add(sum);
+            } else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], offset1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], offset2 + i);
+                    var diff = a.sub(b);
+                    sum = diff.mul(diff).add(sum);
+                }
+                // Process the tail
+                for (; i < centroidLength; ++i) {
+                    var diff = codebooks[m].get(offset1 + i) - codebooks[m].get(offset2 + i);
+                    res += diff * diff;
+                }
+            }
+        }
+        res += sum.reduceLanes(VectorOperators.ADD);
+        return res;
+    }
+
+    @Override
+    public float pqDiversityEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) {
+        // Nothing to accumulate; also avoids reading subvectorSizesAndOffsets[0] when there are no subspaces
+        if (subspaceCount <= 0) {
+            return 0;
+        }
+        // Centroid lengths can vary between subspaces; the first entry is expected to be the largest,
+        // so it decides the widest species that can be filled
+        final int largestCentroidLength = subvectorSizesAndOffsets[0][0];
+        if (largestCentroidLength >= FloatVector.SPECIES_PREFERRED.length()) {
+            return pqDiversityEuclidean_512(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
+        } else if (largestCentroidLength >= FloatVector.SPECIES_256.length()) {
+            return pqDiversityEuclidean_256(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
+        } else if (largestCentroidLength >= FloatVector.SPECIES_128.length()) {
+            // taken for centroids of 4 to 7 floats when the preferred species is wider than 128 bits
+            return pqDiversityEuclidean_128(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
+        } else {
+            // taken for the remaining, shorter centroids
+            return pqDiversityEuclidean_64(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
+        }
+    }
 
 }