From 855d1c613d1e0b66424fe44ed1dab9024c1d54c5 Mon Sep 17 00:00:00 2001 From: "xj.lin" Date: Sun, 24 Mar 2019 20:36:28 +0800 Subject: [PATCH] add unittest for build/search index --- pyengine/engine/ingestion/tests/test_build.py | 33 +++++++++++-- .../engine/retrieval/tests/scheduler_test.py | 3 -- .../engine/retrieval/tests/test_search.py | 48 +++++++++++++++++++ 3 files changed, 77 insertions(+), 7 deletions(-) delete mode 100644 pyengine/engine/retrieval/tests/scheduler_test.py create mode 100644 pyengine/engine/retrieval/tests/test_search.py diff --git a/pyengine/engine/ingestion/tests/test_build.py b/pyengine/engine/ingestion/tests/test_build.py index 4dc90e95..ef349f90 100644 --- a/pyengine/engine/ingestion/tests/test_build.py +++ b/pyengine/engine/ingestion/tests/test_build.py @@ -7,7 +7,9 @@ import unittest class TestBuildIndex(unittest.TestCase): def test_factory_method(self): - pass + index_builder = FactoryIndex() + index = index_builder() + self.assertIsInstance(index, DefaultIndex) def test_default_index(self): d = 64 @@ -30,15 +32,38 @@ class TestBuildIndex(unittest.TestCase): d = 64 nb = 10000 nq = 100 - _, xb, xq = get_dataset(d, nb, 500, nq) + nt = 500 + xt, xb, xq = get_dataset(d, nb, nt, nq) index = faiss.IndexFlatL2(d) index.add(xb) - pass + assert index.ntotal == nb + + Index.increase(index, xt) + assert index.ntotal == nb + nt def test_serialize(self): - pass + d = 64 + nb = 10000 + nq = 100 + nt = 500 + xt, xb, xq = get_dataset(d, nb, nt, nq) + + index = faiss.IndexFlatL2(d) + index.add(xb) + Dref, Iref = index.search(xq, 5) + + ar_data = Index.serialize(index) + + reader = faiss.VectorIOReader() + faiss.copy_array_to_vector(ar_data, reader.data) + index2 = faiss.read_index(reader) + + Dnew, Inew = index2.search(xq, 5) + + assert np.all(Dnew == Dref) and np.all(Inew == Iref) + def get_dataset(d, nb, nt, nq): diff --git a/pyengine/engine/retrieval/tests/scheduler_test.py b/pyengine/engine/retrieval/tests/scheduler_test.py deleted file mode 100644 index 71dc6e7b..00000000 --- a/pyengine/engine/retrieval/tests/scheduler_test.py +++ /dev/null @@ -1,3 +0,0 @@ -from engine.controller import scheduler - -scheduler.Scheduler.Search() \ No newline at end of file diff --git a/pyengine/engine/retrieval/tests/test_search.py b/pyengine/engine/retrieval/tests/test_search.py new file mode 100644 index 00000000..cd3ed927 --- /dev/null +++ b/pyengine/engine/retrieval/tests/test_search.py @@ -0,0 +1,48 @@ +from ..search_index import * + +import unittest +import numpy as np + +class TestSearchSingleThread(unittest.TestCase): + def test_search_by_vectors(self): + d = 64 + nb = 10000 + nq = 100 + _, xb, xq = get_dataset(d, nb, 500, nq) + + index = faiss.IndexFlatL2(d) + index.add(xb) + + # expect result + Dref, Iref = index.search(xq, 5) + + searcher = FaissSearch(index) + result = searcher.search_by_vectors(xq, 5) + + assert np.all(result.distance == Dref) \ + and np.all(result.vectors == Iref) + pass + + def test_top_k(selfs): + pass + + +def get_dataset(d, nb, nt, nq): + """A dataset that is not completely random but still challenging to + index + """ + d1 = 10 # intrinsic dimension (more or less) + n = nb + nt + nq + rs = np.random.RandomState(1338) + x = rs.normal(size=(n, d1)) + x = np.dot(x, rs.rand(d1, d)) + # now we have a d1-dim ellipsoid in d-dimensional space + # higher factor (>4) -> higher frequency -> less linear + x = x * (rs.rand(d) * 4 + 0.1) + x = np.sin(x) + x = x.astype('float32') + return x[:nt], x[nt:-nq], x[-nq:] + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file -- GitLab