5 changes: 3 additions & 2 deletions LiteDB.Tests/LiteDB.Tests.csproj
@@ -30,8 +30,9 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="FluentAssertions" Version="6.12.2" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.12.0" />
<PackageReference Include="FluentAssertions" Version="6.12.2" />
<PackageReference Include="MathNet.Numerics" Version="5.0.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.12.0" />
<PackageReference Include="xunit" Version="2.9.2" />
<PackageReference Include="xunit.runner.console" Version="2.9.2">
<PrivateAssets>all</PrivateAssets>
301 changes: 297 additions & 4 deletions LiteDB.Tests/Query/VectorIndex_Tests.cs
@@ -1,6 +1,7 @@
using FluentAssertions;
using LiteDB;
using LiteDB.Engine;
using MathNet.Numerics.LinearAlgebra;
using System;
using System.Collections.Generic;
using System.IO;
@@ -82,6 +83,89 @@ private static int CountNodes(Snapshot snapshot, PageAddress root)
return count;
}

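// Builds a random vector with components in [-1, 1) and guarantees at least one non-zero entry.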
private static float[] CreateVector(Random random, int dimensions)
{
var vector = new float[dimensions];
var hasNonZero = false;

for (var i = 0; i < dimensions; i++)
{
var value = (float)(random.NextDouble() * 2d - 1d);
vector[i] = value;

if (!hasNonZero && Math.Abs(value) > 1e-6f)
{
hasNonZero = true;
}
}

if (!hasNonZero)
{
vector[random.Next(dimensions)] = 1f;
}

return vector;
}

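// MathNet-based reference metrics: cosine and Euclidean yield only a distance (similarity left NaN);
// dot product yields the negated dot as the distance and the raw dot as the similarity.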
private static (double Distance, double Similarity) ComputeReferenceMetrics(float[] candidate, float[] target, VectorDistanceMetric metric)
{
var builder = Vector<double>.Build;
var candidateVector = builder.DenseOfEnumerable(candidate.Select(v => (double)v));
var targetVector = builder.DenseOfEnumerable(target.Select(v => (double)v));

switch (metric)
{
case VectorDistanceMetric.Cosine:
var candidateNorm = candidateVector.L2Norm();
var targetNorm = targetVector.L2Norm();

if (candidateNorm == 0d || targetNorm == 0d)
{
return (double.NaN, double.NaN);
}

var cosineSimilarity = candidateVector.DotProduct(targetVector) / (candidateNorm * targetNorm);
return (1d - cosineSimilarity, double.NaN);

case VectorDistanceMetric.Euclidean:
return ((candidateVector - targetVector).L2Norm(), double.NaN);

case VectorDistanceMetric.DotProduct:
var dot = candidateVector.DotProduct(targetVector);
return (-dot, dot);

default:
throw new ArgumentOutOfRangeException(nameof(metric), metric, null);
}
}

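// Expected ordering: descending similarity for dot product, ascending distance otherwise;
// NaN results are dropped and ties are broken by Id.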
private static List<(int Id, double Distance, double Similarity)> ComputeExpectedRanking(
IEnumerable<VectorDocument> documents,
float[] target,
VectorDistanceMetric metric,
int? limit = null)
{
var ordered = documents
.Select(doc =>
{
var (distance, similarity) = ComputeReferenceMetrics(doc.Embedding, target, metric);
return (doc.Id, Distance: distance, Similarity: similarity);
})
.Where(result => metric == VectorDistanceMetric.DotProduct
? !double.IsNaN(result.Similarity)
: !double.IsNaN(result.Distance))
.OrderBy(result => metric == VectorDistanceMetric.DotProduct ? -result.Similarity : result.Distance)
.ThenBy(result => result.Id)
.ToList();

if (limit.HasValue)
{
ordered = ordered.Take(limit.Value).ToList();
}

return ordered;
}

[Fact]
public void EnsureVectorIndex_CreatesAndReuses()
{
@@ -347,9 +431,12 @@ public void VectorIndex_Search_Prunes_Node_Visits()
using var db = new LiteDatabase(":memory:");
var collection = db.GetCollection<VectorDocument>("vectors");

const int nearClusterSize = 64;
const int farClusterSize = 64;

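// Two clusters of documents: the assertions below expect only ids from the near cluster to match,
// and the search to visit fewer nodes than the index holds.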
var documents = new List<VectorDocument>();

- for (var i = 0; i < 32; i++)
+ for (var i = 0; i < nearClusterSize; i++)
{
documents.Add(new VectorDocument
{
@@ -359,17 +446,18 @@ public void VectorIndex_Search_Prunes_Node_Visits()
});
}

- for (var i = 0; i < 32; i++)
+ for (var i = 0; i < farClusterSize; i++)
{
documents.Add(new VectorDocument
{
- Id = i + 33,
+ Id = i + nearClusterSize + 1,
Embedding = new[] { -1f, 2f + i / 100f },
Flag = false
});
}

collection.Insert(documents);
collection.Count().Should().Be(documents.Count);

collection.EnsureIndex(
"embedding_idx",
@@ -389,7 +477,8 @@ public void VectorIndex_Search_Prunes_Node_Visits()
});

stats.Total.Should().BeGreaterThan(stats.Visited);
- stats.Matches.Should().OnlyContain(id => id < 32);
+ stats.Total.Should().BeGreaterOrEqualTo(nearClusterSize);
+ stats.Matches.Should().OnlyContain(id => id <= nearClusterSize);
}

[Fact]
@@ -467,5 +556,209 @@ public void VectorIndex_PersistsNodes_WhenDocumentsChange()
return 0;
});
}

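// Cross-checks VectorIndexService.ComputeDistance against the MathNet reference for random vector pairs.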
[Theory]
[InlineData(VectorDistanceMetric.Cosine)]
[InlineData(VectorDistanceMetric.Euclidean)]
[InlineData(VectorDistanceMetric.DotProduct)]
public void VectorDistance_Computation_MatchesMathNet(VectorDistanceMetric metric)
{
var random = new Random(1789);
const int dimensions = 6;

for (var i = 0; i < 20; i++)
{
var candidate = CreateVector(random, dimensions);
var target = CreateVector(random, dimensions);

var distance = VectorIndexService.ComputeDistance(candidate, target, metric, out var similarity);
var (expectedDistance, expectedSimilarity) = ComputeReferenceMetrics(candidate, target, metric);

if (double.IsNaN(expectedDistance))
{
double.IsNaN(distance).Should().BeTrue();
}
else
{
distance.Should().BeApproximately(expectedDistance, 1e-6);
}

if (double.IsNaN(expectedSimilarity))
{
double.IsNaN(similarity).Should().BeTrue();
}
else
{
similarity.Should().BeApproximately(expectedSimilarity, 1e-6);
}
}

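// Cosine distance is undefined for a zero vector, so both outputs should be NaN.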
if (metric == VectorDistanceMetric.Cosine)
{
var zero = new float[dimensions];
var other = CreateVector(random, dimensions);

var distance = VectorIndexService.ComputeDistance(zero, other, metric, out var similarity);

double.IsNaN(distance).Should().BeTrue();
double.IsNaN(similarity).Should().BeTrue();
}
}

[Theory]
[InlineData(VectorDistanceMetric.Cosine)]
[InlineData(VectorDistanceMetric.Euclidean)]
[InlineData(VectorDistanceMetric.DotProduct)]
public void VectorIndex_Search_MatchesReferenceRanking(VectorDistanceMetric metric)
{
using var db = new LiteDatabase(":memory:");
var collection = db.GetCollection<VectorDocument>("vectors");

var random = new Random(4242);
const int dimensions = 6;

var documents = Enumerable.Range(1, 32)
.Select(i => new VectorDocument
{
Id = i,
Embedding = CreateVector(random, dimensions),
Flag = i % 2 == 0
})
.ToList();

collection.Insert(documents);

collection.EnsureIndex(
"embedding_idx",
BsonExpression.Create("$.Embedding"),
new VectorIndexOptions((ushort)dimensions, metric));

var target = CreateVector(random, dimensions);
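// For each limit, the index search should reproduce the reference ranking; dot product
// results are compared as similarities, the other metrics as distances.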
foreach (var limit in new[] { 5, 12 })
{
var expectedTop = ComputeExpectedRanking(documents, target, metric, limit);

var actual = InspectVectorIndex(db, "vectors", (snapshot, collation, metadata) =>
{
var service = new VectorIndexService(snapshot, collation);
return service.Search(metadata, target, double.MaxValue, limit)
.Select(result =>
{
var mapped = BsonMapper.Global.ToObject<VectorDocument>(result.Document);
return (Id: mapped.Id, Score: result.Distance);
})
.ToList();
});

actual.Should().HaveCount(expectedTop.Count);

for (var i = 0; i < expectedTop.Count; i++)
{
actual[i].Id.Should().Be(expectedTop[i].Id);

if (metric == VectorDistanceMetric.DotProduct)
{
actual[i].Score.Should().BeApproximately(expectedTop[i].Similarity, 1e-6);
}
else
{
actual[i].Score.Should().BeApproximately(expectedTop[i].Distance, 1e-6);
}
}
}
}

[Theory]
[InlineData(VectorDistanceMetric.Cosine)]
[InlineData(VectorDistanceMetric.Euclidean)]
[InlineData(VectorDistanceMetric.DotProduct)]
public void WhereNear_MatchesReferenceOrdering(VectorDistanceMetric metric)
{
using var db = new LiteDatabase(":memory:");
var collection = db.GetCollection<VectorDocument>("vectors");

var random = new Random(9182);
const int dimensions = 6;

var documents = Enumerable.Range(1, 40)
.Select(i => new VectorDocument
{
Id = i,
Embedding = CreateVector(random, dimensions),
Flag = i % 3 == 0
})
.ToList();

collection.Insert(documents);

collection.EnsureIndex(
"embedding_idx",
BsonExpression.Create("$.Embedding"),
new VectorIndexOptions((ushort)dimensions, metric));

var target = CreateVector(random, dimensions);
const int limit = 12;

var query = collection.Query()
.WhereNear(x => x.Embedding, target, double.MaxValue)
.Limit(limit);

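// The query plan should report that the vector index drives the search.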
var plan = query.GetPlan();
plan["index"]["mode"].AsString.Should().Be("VECTOR INDEX SEARCH");

var results = query.ToArray();

results.Should().HaveCount(limit);

var searchIds = InspectVectorIndex(db, "vectors", (snapshot, collation, metadata) =>
{
var service = new VectorIndexService(snapshot, collation);
return service.Search(metadata, target, double.MaxValue, limit)
.Select(result => BsonMapper.Global.ToObject<VectorDocument>(result.Document).Id)
.ToArray();
});

results.Select(x => x.Id).Should().Equal(searchIds);
}

[Theory]
[InlineData(VectorDistanceMetric.Cosine)]
[InlineData(VectorDistanceMetric.Euclidean)]
[InlineData(VectorDistanceMetric.DotProduct)]
public void TopKNear_MatchesReferenceOrdering(VectorDistanceMetric metric)
{
using var db = new LiteDatabase(":memory:");
var collection = db.GetCollection<VectorDocument>("vectors");

var random = new Random(5461);
const int dimensions = 6;

var documents = Enumerable.Range(1, 48)
.Select(i => new VectorDocument
{
Id = i,
Embedding = CreateVector(random, dimensions),
Flag = i % 4 == 0
})
.ToList();

collection.Insert(documents);

collection.EnsureIndex(
"embedding_idx",
BsonExpression.Create("$.Embedding"),
new VectorIndexOptions((ushort)dimensions, metric));

var target = CreateVector(random, dimensions);
const int limit = 7;
var expected = ComputeExpectedRanking(documents, target, metric, limit);

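// TopKNear should return the reference top-k ids in the same order.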
var results = collection.Query()
.TopKNear(x => x.Embedding, target, limit)
.ToArray();

results.Should().HaveCount(expected.Count);
results.Select(x => x.Id).Should().Equal(expected.Select(x => x.Id));
}
}
}