@@ -7,6 +7,7 @@ package faiss
7
7
#include <faiss/c_api/IndexIVF_c.h>
8
8
#include <faiss/c_api/IndexIVF_c_ex.h>
9
9
#include <faiss/c_api/IndexBinary_c.h>
10
+ #include <faiss/c_api/IndexBinaryIVF_c.h>
10
11
#include <faiss/c_api/index_factory_c.h>
11
12
#include <faiss/c_api/MetaIndexes_c.h>
12
13
#include <faiss/c_api/impl/AuxIndexStructures_c.h>
@@ -54,6 +55,10 @@ type BinaryIndex interface {
54
55
SearchBinaryWithIDs (x []uint8 , k int64 , include []int64 , params json.RawMessage ) ([]int32 , []int64 , error )
55
56
SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 , params json.RawMessage ) (distances []int32 ,
56
57
labels []int64 , err error )
58
+
59
+ ObtainClusterVectorCountsFromIVFIndex (vecIDs []int64 ) (map [int64 ]int64 , error )
60
+ ObtainClustersWithDistancesFromIVFIndex (x []uint8 , centroidIDs []int64 ) (
61
+ []int64 , []int32 , error )
57
62
}
58
63
59
64
// FloatIndex defines methods specific to float-based FAISS indexes
@@ -156,6 +161,62 @@ func (idx *BinaryIndexImpl) Close() {
156
161
}
157
162
}
158
163
164
+ func (idx * BinaryIndexImpl ) ObtainClusterVectorCountsFromIVFIndex (vecIDs []int64 ) (map [int64 ]int64 , error ) {
165
+ if ! idx .IsIVFIndex () {
166
+ return nil , fmt .Errorf ("index is not an IVF index" )
167
+ }
168
+ clusterIDs := make ([]int64 , len (vecIDs ))
169
+ if c := C .faiss_get_lists_for_keys_binary (
170
+ idx .indexPtr ,
171
+ (* C .idx_t )(unsafe .Pointer (& vecIDs [0 ])),
172
+ (C .size_t )(len (vecIDs )),
173
+ (* C .idx_t )(unsafe .Pointer (& clusterIDs [0 ])),
174
+ ); c != 0 {
175
+ return nil , getLastError ()
176
+ }
177
+ rv := make (map [int64 ]int64 , len (vecIDs ))
178
+ for _ , v := range clusterIDs {
179
+ rv [v ]++
180
+ }
181
+ return rv , nil
182
+ }
183
+
184
+ func (idx * BinaryIndexImpl ) ObtainClustersWithDistancesFromIVFIndex (x []uint8 , centroidIDs []int64 ) (
185
+ []int64 , []int32 , error ) {
186
+ // Selector to include only the centroids whose IDs are part of 'centroidIDs'.
187
+ includeSelector , err := NewIDSelectorBatch (centroidIDs )
188
+ if err != nil {
189
+ return nil , nil , err
190
+ }
191
+ defer includeSelector .Delete ()
192
+
193
+ params , err := NewSearchParams (idx , json.RawMessage {}, includeSelector .Get (), nil )
194
+ if err != nil {
195
+ return nil , nil , err
196
+ }
197
+ defer params .Delete ()
198
+
199
+ // Populate these with the centroids and their distances.
200
+ centroids := make ([]int64 , len (centroidIDs ))
201
+ centroidDistances := make ([]int32 , len (centroidIDs ))
202
+
203
+ n := len (x ) / idx .D ()
204
+
205
+ c := C .faiss_Search_closest_eligible_centroids_binary (
206
+ idx .indexPtr ,
207
+ (C .idx_t )(n ),
208
+ (* C .uint8_t )(& x [0 ]),
209
+ (C .idx_t )(len (centroidIDs )),
210
+ (* C .int32_t )(& centroidDistances [0 ]),
211
+ (* C .idx_t )(& centroids [0 ]),
212
+ params .sp )
213
+ if c != 0 {
214
+ return nil , nil , getLastError ()
215
+ }
216
+
217
+ return centroids , centroidDistances , nil
218
+ }
219
+
159
220
func (idx * BinaryIndexImpl ) Size () uint64 {
160
221
return 0
161
222
}
@@ -263,7 +324,7 @@ func (idx *BinaryIndexImpl) Train(vectors []uint8) error {
263
324
}
264
325
265
326
func (idx * BinaryIndexImpl ) SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 , params json.RawMessage ) (distances []int32 , labels []int64 , err error ) {
266
- if len (exclude ) == 0 && params == nil {
327
+ if len (exclude ) == 0 && len ( params ) == 0 {
267
328
return idx .SearchBinary (x , k )
268
329
}
269
330
0 commit comments