提交 f5cb7fa2 编写于 作者: X xj.lin

add id

上级 cfbe86df
...@@ -36,9 +36,10 @@ class DefaultIndex(Index): ...@@ -36,9 +36,10 @@ class DefaultIndex(Index):
pass pass
def build(self, d, vectors, vector_ids, DEVICE=INDEXDEVICES.CPU): def build(self, d, vectors, vector_ids, DEVICE=INDEXDEVICES.CPU):
index = faiss.IndexFlatL2(d) # trained index = faiss.IndexFlatL2(d)
index.add(vectors) index2 = faiss.IndexIDMap(index)
return index index2.add_with_ids(vectors, vector_ids)
return index2
class LowMemoryIndex(Index): class LowMemoryIndex(Index):
......
...@@ -16,14 +16,16 @@ class TestBuildIndex(unittest.TestCase): ...@@ -16,14 +16,16 @@ class TestBuildIndex(unittest.TestCase):
nb = 10000 nb = 10000
nq = 100 nq = 100
_, xb, xq = get_dataset(d, nb, 500, nq) _, xb, xq = get_dataset(d, nb, 500, nq)
ids = np.arange(xb.shape[0])
# Expected result # Expected result
index = faiss.IndexFlatL2(d) index = faiss.IndexFlatL2(d)
index.add(xb) index2 = faiss.IndexIDMap(index)
index2.add_with_ids(xb, ids)
Dref, Iref = index.search(xq, 5) Dref, Iref = index.search(xq, 5)
builder = DefaultIndex() builder = DefaultIndex()
index2 = builder.build(d, xb) index2 = builder.build(d, xb, ids)
Dnew, Inew = index2.search(xq, 5) Dnew, Inew = index2.search(xq, 5)
assert np.all(Dnew == Dref) and np.all(Inew == Iref) assert np.all(Dnew == Dref) and np.all(Inew == Iref)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册