提交 855d1c61 编写于 作者: X xj.lin

add unittest for build/search index

上级 d37f89dc
...@@ -7,7 +7,9 @@ import unittest ...@@ -7,7 +7,9 @@ import unittest
class TestBuildIndex(unittest.TestCase): class TestBuildIndex(unittest.TestCase):
def test_factory_method(self): def test_factory_method(self):
pass index_builder = FactoryIndex()
index = index_builder()
self.assertIsInstance(index, DefaultIndex)
def test_default_index(self): def test_default_index(self):
d = 64 d = 64
...@@ -30,15 +32,38 @@ class TestBuildIndex(unittest.TestCase): ...@@ -30,15 +32,38 @@ class TestBuildIndex(unittest.TestCase):
d = 64 d = 64
nb = 10000 nb = 10000
nq = 100 nq = 100
_, xb, xq = get_dataset(d, nb, 500, nq) nt = 500
xt, xb, xq = get_dataset(d, nb, nt, nq)
index = faiss.IndexFlatL2(d) index = faiss.IndexFlatL2(d)
index.add(xb) index.add(xb)
pass assert index.ntotal == nb
Index.increase(index, xt)
assert index.ntotal == nb + nt
def test_serialize(self): def test_serialize(self):
pass d = 64
nb = 10000
nq = 100
nt = 500
xt, xb, xq = get_dataset(d, nb, nt, nq)
index = faiss.IndexFlatL2(d)
index.add(xb)
Dref, Iref = index.search(xq, 5)
ar_data = Index.serialize(index)
reader = faiss.VectorIOReader()
faiss.copy_array_to_vector(ar_data, reader.data)
index2 = faiss.read_index(reader)
Dnew, Inew = index2.search(xq, 5)
assert np.all(Dnew == Dref) and np.all(Inew == Iref)
def get_dataset(d, nb, nt, nq): def get_dataset(d, nb, nt, nq):
......
from engine.controller import scheduler
scheduler.Scheduler.Search()
\ No newline at end of file
from ..search_index import *
import unittest
import numpy as np
class TestSearchSingleThread(unittest.TestCase):
def test_search_by_vectors(self):
d = 64
nb = 10000
nq = 100
_, xb, xq = get_dataset(d, nb, 500, nq)
index = faiss.IndexFlatL2(d)
index.add(xb)
# expect result
Dref, Iref = index.search(xq, 5)
searcher = FaissSearch(index)
result = searcher.search_by_vectors(xq, 5)
assert np.all(result.distance == Dref) \
and np.all(result.vectors == Iref)
pass
def test_top_k(selfs):
pass
def get_dataset(d, nb, nt, nq):
"""A dataset that is not completely random but still challenging to
index
"""
d1 = 10 # intrinsic dimension (more or less)
n = nb + nt + nq
rs = np.random.RandomState(1338)
x = rs.normal(size=(n, d1))
x = np.dot(x, rs.rand(d1, d))
# now we have a d1-dim ellipsoid in d-dimensional space
# higher factor (>4) -> higher frequency -> less linear
x = x * (rs.rand(d) * 4 + 0.1)
x = np.sin(x)
x = x.astype('float32')
return x[:nt], x[nt:-nq], x[-nq:]
if __name__ == "__main__":
unittest.main()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册