Skip to content

Commit 20b3ae4

Browse files
bivf pre-filtering utils
1 parent bfb5436 commit 20b3ae4

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

index.go

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package faiss
77
#include <faiss/c_api/IndexIVF_c.h>
88
#include <faiss/c_api/IndexIVF_c_ex.h>
99
#include <faiss/c_api/IndexBinary_c.h>
10+
#include <faiss/c_api/IndexBinaryIVF_c.h>
1011
#include <faiss/c_api/index_factory_c.h>
1112
#include <faiss/c_api/MetaIndexes_c.h>
1213
#include <faiss/c_api/impl/AuxIndexStructures_c.h>
@@ -54,6 +55,10 @@ type BinaryIndex interface {
5455
SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) ([]int32, []int64, error)
5556
SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32,
5657
labels []int64, err error)
58+
59+
ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error)
60+
ObtainClustersWithDistancesFromIVFIndex(x []uint8, centroidIDs []int64) (
61+
[]int64, []int32, error)
5762
}
5863

5964
// FloatIndex defines methods specific to float-based FAISS indexes
@@ -156,6 +161,62 @@ func (idx *BinaryIndexImpl) Close() {
156161
}
157162
}
158163

164+
func (idx *BinaryIndexImpl) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) {
165+
if !idx.IsIVFIndex() {
166+
return nil, fmt.Errorf("index is not an IVF index")
167+
}
168+
clusterIDs := make([]int64, len(vecIDs))
169+
if c := C.faiss_get_lists_for_keys_binary(
170+
idx.indexPtr,
171+
(*C.idx_t)(unsafe.Pointer(&vecIDs[0])),
172+
(C.size_t)(len(vecIDs)),
173+
(*C.idx_t)(unsafe.Pointer(&clusterIDs[0])),
174+
); c != 0 {
175+
return nil, getLastError()
176+
}
177+
rv := make(map[int64]int64, len(vecIDs))
178+
for _, v := range clusterIDs {
179+
rv[v]++
180+
}
181+
return rv, nil
182+
}
183+
184+
func (idx *BinaryIndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []uint8, centroidIDs []int64) (
185+
[]int64, []int32, error) {
186+
// Selector to include only the centroids whose IDs are part of 'centroidIDs'.
187+
includeSelector, err := NewIDSelectorBatch(centroidIDs)
188+
if err != nil {
189+
return nil, nil, err
190+
}
191+
defer includeSelector.Delete()
192+
193+
params, err := NewSearchParams(idx, json.RawMessage{}, includeSelector.Get(), nil)
194+
if err != nil {
195+
return nil, nil, err
196+
}
197+
defer params.Delete()
198+
199+
// Populate these with the centroids and their distances.
200+
centroids := make([]int64, len(centroidIDs))
201+
centroidDistances := make([]int32, len(centroidIDs))
202+
203+
n := len(x) / idx.D()
204+
205+
c := C.faiss_Search_closest_eligible_centroids_binary(
206+
idx.indexPtr,
207+
(C.idx_t)(n),
208+
(*C.uint8_t)(&x[0]),
209+
(C.idx_t)(len(centroidIDs)),
210+
(*C.int32_t)(&centroidDistances[0]),
211+
(*C.idx_t)(&centroids[0]),
212+
params.sp)
213+
if c != 0 {
214+
return nil, nil, getLastError()
215+
}
216+
217+
return centroids, centroidDistances, nil
218+
}
219+
159220
func (idx *BinaryIndexImpl) Size() uint64 {
160221
return 0
161222
}
@@ -263,7 +324,7 @@ func (idx *BinaryIndexImpl) Train(vectors []uint8) error {
263324
}
264325

265326
func (idx *BinaryIndexImpl) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, labels []int64, err error) {
266-
if len(exclude) == 0 && params == nil {
327+
if len(exclude) == 0 && len(params) == 0 {
267328
return idx.SearchBinary(x, k)
268329
}
269330

index_ivf.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package faiss
66
#include <faiss/c_api/Index_c.h>
77
#include <faiss/c_api/IndexIVF_c.h>
88
#include <faiss/c_api/IndexBinary_c.h>
9+
#include <faiss/c_api/IndexBinaryIVF_c.h>
910
#include <faiss/c_api/IndexIVF_c_ex.h>
1011
*/
1112
import "C"

0 commit comments

Comments
 (0)