# test_build.py: unit tests for the index builders imported from build_index.
from ..build_index import *

import faiss
import numpy as np
import unittest


class TestBuildIndex(unittest.TestCase):
    """Tests for the index builders and ``Index`` helpers from ``build_index``.

    Every test that needs vectors uses ``get_dataset``, which is seeded, so
    the comparisons against reference faiss indexes are reproducible.
    """

    def test_factory_method(self):
        """Calling a ``FactoryIndex`` instance yields a ``DefaultIndex``."""
        index_builder = FactoryIndex()
        index = index_builder()
        self.assertIsInstance(index, DefaultIndex)

    def test_default_index(self):
        """``DefaultIndex.build`` matches an ID-mapped flat L2 reference index."""
        d = 64
        nb = 10000
        nq = 100
        _, xb, xq = get_dataset(d, nb, 500, nq)
        # faiss takes 64-bit ids; the default integer dtype of np.arange is
        # platform dependent, so request int64 explicitly.
        ids = np.arange(xb.shape[0], dtype=np.int64)

        # Expected result: brute-force L2 search wrapped in an ID map.
        # The flat index keeps its own name so it stays referenced while the
        # wrapper is in use.
        flat_index = faiss.IndexFlatL2(d)
        ref_index = faiss.IndexIDMap(flat_index)
        ref_index.add_with_ids(xb, ids)
        # Search through the wrapper so the expected labels are the ids
        # themselves (identical to positions here because ids is an arange).
        Dref, Iref = ref_index.search(xq, 5)

        builder = DefaultIndex()
        built_index = builder.build(d, xb, ids)
        Dnew, Inew = built_index.search(xq, 5)

        # Unlike a bare ``assert``, these survive ``python -O`` and report
        # which elements differ on failure.
        np.testing.assert_array_equal(Dnew, Dref)
        np.testing.assert_array_equal(Inew, Iref)

    def test_increase(self):
        """``Index.increase`` grows ``ntotal`` by the number of added vectors."""
        d = 64
        nb = 10000
        nq = 100
        nt = 500
        xt, xb, _ = get_dataset(d, nb, nt, nq)

        index = faiss.IndexFlatL2(d)
        index.add(xb)

        self.assertEqual(index.ntotal, nb)

        Index.increase(index, xt)
        self.assertEqual(index.ntotal, nb + nt)

    def test_serialize(self):
        """An index rebuilt from ``Index.serialize`` output searches identically."""
        d = 64
        nb = 10000
        nq = 100
        nt = 500
        _, xb, xq = get_dataset(d, nb, nt, nq)

        index = faiss.IndexFlatL2(d)
        index.add(xb)
        Dref, Iref = index.search(xq, 5)

        ar_data = Index.serialize(index)

        # Feed the serialized array back to faiss through an in-memory reader.
        reader = faiss.VectorIOReader()
        faiss.copy_array_to_vector(ar_data, reader.data)
        restored_index = faiss.read_index(reader)

        Dnew, Inew = restored_index.search(xq, 5)

        np.testing.assert_array_equal(Dnew, Dref)
        np.testing.assert_array_equal(Inew, Iref)

# Synthetic dataset helper used by the tests above.

def get_dataset(d, nb, nt, nq):
    """Build a dataset that is not completely random but still challenging
    to index.

    The generator is seeded with a fixed value, so repeated calls with the
    same arguments return identical arrays.

    Args:
        d: dimensionality of each vector.
        nb: number of base vectors.
        nt: number of training vectors.
        nq: number of query vectors.

    Returns:
        A tuple ``(xt, xb, xq)`` of float32 arrays with shapes ``(nt, d)``,
        ``(nb, d)`` and ``(nq, d)``.
    """
    d1 = 10  # intrinsic dimension (more or less)
    n = nb + nt + nq
    rs = np.random.RandomState(1338)
    x = rs.normal(size=(n, d1))
    x = np.dot(x, rs.rand(d1, d))
    # now we have a d1-dim ellipsoid in d-dimensional space
    # higher factor (>4) -> higher frequency -> less linear
    x = x * (rs.rand(d) * 4 + 0.1)
    x = np.sin(x)
    x = x.astype('float32')
    # Slice with explicit offsets: negative-index slicing (x[nt:-nq], x[-nq:])
    # breaks for nq == 0, yielding an empty base set and all rows as queries.
    base_end = nt + nb
    return x[:nt], x[nt:base_end], x[base_end:]


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()