Skip to content

Commit 279e046

Browse files
committed
update pgvector Follow reviewer's advice
Signed-off-by: Abirdcfly <[email protected]>
1 parent 46b24e7 commit 279e046

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

vectorstores/pgvector/pgvector.go

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ PRIMARY KEY (uuid))`, s.embeddingTableName, s.collectionTableName)
151151

152152
func (s Store) AddDocuments(ctx context.Context, docs []schema.Document, options ...vectorstores.Option) error {
153153
opts := s.getOptions(options...)
154-
if opts.Embedder != nil || opts.ScoreThreshold != 0 || opts.Filters != nil || opts.NameSpace != "" {
154+
if opts.ScoreThreshold != 0 || opts.Filters != nil || opts.NameSpace != "" {
155155
return ErrUnsupportedOptions
156156
}
157157

@@ -160,7 +160,11 @@ func (s Store) AddDocuments(ctx context.Context, docs []schema.Document, options
160160
texts = append(texts, doc.PageContent)
161161
}
162162

163-
vectors, err := s.embedder.EmbedDocuments(ctx, texts)
163+
embedder := s.embedder
164+
if opts.Embedder != nil {
165+
embedder = opts.Embedder
166+
}
167+
vectors, err := embedder.EmbedDocuments(ctx, texts)
164168
if err != nil {
165169
return err
166170
}
@@ -186,11 +190,8 @@ func (s Store) SimilaritySearch(
186190
numDocuments int,
187191
options ...vectorstores.Option,
188192
) ([]schema.Document, error) {
189-
collectionName := s.collectionName
190193
opts := s.getOptions(options...)
191-
if nameSpace := s.getNameSpace(opts); nameSpace != "" {
192-
collectionName = nameSpace
193-
}
194+
collectionName := s.getNameSpace(opts)
194195
scoreThreshold, err := s.getScoreThreshold(opts)
195196
if err != nil {
196197
return nil, err
@@ -199,7 +200,11 @@ func (s Store) SimilaritySearch(
199200
if err != nil {
200201
return nil, err
201202
}
202-
embedder, err := s.embedder.EmbedQuery(ctx, query)
203+
embedder := s.embedder
204+
if opts.Embedder != nil {
205+
embedder = opts.Embedder
206+
}
207+
embedderData, err := embedder.EmbedQuery(ctx, query)
203208
if err != nil {
204209
return nil, err
205210
}
@@ -236,7 +241,7 @@ LIMIT %d`, s.embeddingTableName,
236241
s.embeddingTableName,
237242
s.collectionTableName, s.embeddingTableName, s.collectionTableName, s.collectionTableName, collectionName,
238243
whereQuery, numDocuments)
239-
rows, err := tx.Query(ctx, sql, pgvector.NewVector(embedder))
244+
rows, err := tx.Query(ctx, sql, pgvector.NewVector(embedderData))
240245
if err != nil {
241246
return nil, err
242247
}
@@ -295,6 +300,8 @@ func (s Store) createOrGetCollection(ctx context.Context) (string, error) {
295300
return collectionUUID, nil
296301
}
297302

303+
// getOptions applies given options to default Options and returns it
304+
// This uses options pattern so clients can easily pass options without changing function signature.
298305
func (s Store) getOptions(options ...vectorstores.Option) vectorstores.Options {
299306
opts := vectorstores.Options{}
300307
for _, opt := range options {
@@ -307,7 +314,7 @@ func (s Store) getNameSpace(opts vectorstores.Options) string {
307314
if opts.NameSpace != "" {
308315
return opts.NameSpace
309316
}
310-
return ""
317+
return s.collectionName
311318
}
312319

313320
func (s Store) getScoreThreshold(opts vectorstores.Options) (float32, error) {
@@ -317,6 +324,7 @@ func (s Store) getScoreThreshold(opts vectorstores.Options) (float32, error) {
317324
return opts.ScoreThreshold, nil
318325
}
319326

327+
// getFilters return metadata filters, now only support map[key]value pattern
320328
// TODO: should support more types like {"key1": {"key2":"values2"}} or {"key": ["value1", "values2"]}.
321329
func (s Store) getFilters(opts vectorstores.Options) (map[string]any, error) {
322330
if opts.Filters != nil {

vectorstores/pgvector/pgvector_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func preCheckEnvSetting(t *testing.T) {
3030
}
3131
}
3232

33-
func getTestCollectionName() string {
33+
func makeNewCollectionName() string {
3434
return fmt.Sprintf("test-collection-%s", uuid.New().String())
3535
}
3636

@@ -54,7 +54,7 @@ func TestPgvectorStoreRest(t *testing.T) {
5454
ctx,
5555
pgvector.WithEmbedder(e),
5656
pgvector.WithPreDeleteCollection(true),
57-
pgvector.WithCollectionName(getTestCollectionName()),
57+
pgvector.WithCollectionName(makeNewCollectionName()),
5858
)
5959
require.NoError(t, err)
6060

@@ -89,7 +89,7 @@ func TestPgvectorStoreRestWithScoreThreshold(t *testing.T) {
8989
ctx,
9090
pgvector.WithEmbedder(e),
9191
pgvector.WithPreDeleteCollection(true),
92-
pgvector.WithCollectionName(getTestCollectionName()),
92+
pgvector.WithCollectionName(makeNewCollectionName()),
9393
)
9494
require.NoError(t, err)
9595

@@ -143,7 +143,7 @@ func TestSimilaritySearchWithInvalidScoreThreshold(t *testing.T) {
143143
ctx,
144144
pgvector.WithEmbedder(e),
145145
pgvector.WithPreDeleteCollection(true),
146-
pgvector.WithCollectionName(getTestCollectionName()),
146+
pgvector.WithCollectionName(makeNewCollectionName()),
147147
)
148148
require.NoError(t, err)
149149

@@ -194,7 +194,7 @@ func TestPgvectorAsRetriever(t *testing.T) {
194194
ctx,
195195
pgvector.WithEmbedder(e),
196196
pgvector.WithPreDeleteCollection(true),
197-
pgvector.WithCollectionName(getTestCollectionName()),
197+
pgvector.WithCollectionName(makeNewCollectionName()),
198198
)
199199
require.NoError(t, err)
200200

@@ -236,7 +236,7 @@ func TestPgvectorAsRetrieverWithScoreThreshold(t *testing.T) {
236236
ctx,
237237
pgvector.WithEmbedder(e),
238238
pgvector.WithPreDeleteCollection(true),
239-
pgvector.WithCollectionName(getTestCollectionName()),
239+
pgvector.WithCollectionName(makeNewCollectionName()),
240240
)
241241
require.NoError(t, err)
242242

@@ -283,7 +283,7 @@ func TestPgvectorAsRetrieverWithMetadataFilterNotSelected(t *testing.T) {
283283
ctx,
284284
pgvector.WithEmbedder(e),
285285
pgvector.WithPreDeleteCollection(true),
286-
pgvector.WithCollectionName(getTestCollectionName()),
286+
pgvector.WithCollectionName(makeNewCollectionName()),
287287
)
288288
require.NoError(t, err)
289289

@@ -357,7 +357,7 @@ func TestPgvectorAsRetrieverWithMetadataFilters(t *testing.T) {
357357
ctx,
358358
pgvector.WithEmbedder(e),
359359
pgvector.WithPreDeleteCollection(true),
360-
pgvector.WithCollectionName(getTestCollectionName()),
360+
pgvector.WithCollectionName(makeNewCollectionName()),
361361
)
362362
require.NoError(t, err)
363363

0 commit comments

Comments
 (0)