test_scheduler.py 2.0 KB
Newer Older
X
xj.lin 已提交
1 2
from ..scheduler import *

X
xj.lin 已提交
3 4 5 6 7 8 9 10 11
import os
import tempfile
import unittest

import faiss
import numpy as np


class TestScheduler(unittest.TestCase):
    """Tests for ``Scheduler.search`` (imported from ``..scheduler``)."""

    def test_schedule(self):
        """``Scheduler.search`` must return the same neighbour ids as a
        direct faiss search, whether it is handed an on-disk index file or
        raw vectors together with their ids.
        """
        d = 64
        nb = 10000
        nq = 2
        nt = 5000
        k = 5
        # xt is only needed once query 3 below is enabled.
        xt, xb, xq = get_dataset(d, nb, nt, nq)
        # faiss ids are 64-bit; np.arange's default int dtype is platform
        # dependent (int32 on Windows with NumPy < 2), so pin it explicitly.
        ids_xb = np.arange(xb.shape[0], dtype=np.int64)

        # Use a private temp directory that is removed after the test rather
        # than a fixed /tmp path, which leaked the file and could collide
        # between concurrent runs.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        file_name = os.path.join(tmp_dir.name, "tempfile_1")

        index = faiss.IndexFlatL2(d)
        # IndexIDMap wraps ``index``: vectors added through it are stored in
        # ``index`` while the wrapper keeps the external ids.
        index_with_ids = faiss.IndexIDMap(index)
        index_with_ids.add_with_ids(xb, ids_xb)
        # ids_xb equals the row positions, so the flat index's labels are
        # the same ids the scheduler is expected to return.
        _, Iref = index.search(xq, k)
        faiss.write_index(index, file_name)

        scheduler_instance = Scheduler()

        # query 1: search through an on-disk index file.
        query_index = {'index': [file_name]}
        vectors = scheduler_instance.search(query_index, vectors=xq, k=k)
        # assert_array_equal also checks the shape and, unlike a bare
        # ``assert``, is not stripped when Python runs with -O.
        np.testing.assert_array_equal(vectors, Iref)

        # query 2: search raw vectors with explicit ids.
        query_index = {'raw': xb, 'raw_id': ids_xb, 'dimension': d}
        vectors = scheduler_instance.search(query_index, vectors=xq, k=k)
        np.testing.assert_array_equal(vectors, Iref)

        # query 3: raw vectors combined with an index file.
        # TODO(linxj): continue...
        # query_index = {
        #     'raw': xt,
        #     'raw_id': np.arange(xt.shape[0], dtype=np.int64),
        #     'dimension': d,
        #     'index': [file_name],
        # }
        # vectors = scheduler_instance.search(query_index, vectors=xq, k=k)
        # np.testing.assert_array_equal(vectors, Iref)
X
xj.lin 已提交
50 51 52 53 54 55


def get_dataset(d, nb, nt, nq, d1=10, seed=1338):
    """Build a synthetic dataset that is not completely random but still
    challenging to index.

    Points are drawn from a ``d1``-dimensional Gaussian, mapped linearly into
    ``d`` dimensions, then scaled per dimension and passed through ``sin``,
    which makes the embedding non-linear.

    Args:
        d: dimension of the returned vectors.
        nb: number of database vectors.
        nt: number of training vectors.
        nq: number of query vectors (0 is allowed).
        d1: intrinsic dimension of the underlying Gaussian (more or less).
        seed: seed for ``np.random.RandomState``; fixed by default so runs
            are reproducible.

    Returns:
        ``(xt, xb, xq)``: float32 arrays of shapes ``(nt, d)``, ``(nb, d)``
        and ``(nq, d)``, all views into one generated array.
    """
    n = nb + nt + nq
    rs = np.random.RandomState(seed)
    x = rs.normal(size=(n, d1))
    x = np.dot(x, rs.rand(d1, d))
    # now we have a d1-dim ellipsoid in d-dimensional space
    # higher factor (>4) -> higher frequency -> less linear
    x = x * (rs.rand(d) * 4 + 0.1)
    x = np.sin(x).astype('float32')
    # Slice with explicit end offsets: the negative form ``x[nt:-nq]`` /
    # ``x[-nq:]`` breaks for nq == 0 (empty xb, every row returned as queries).
    return x[:nt], x[nt:nt + nb], x[nt + nb:]

X
xj.lin 已提交
68

X
xj.lin 已提交
69
if __name__ == "__main__":
    # Run this module's tests when executed directly.
    # NOTE(review): because of the relative ``from ..scheduler`` import, this
    # presumably only works when run as a module (``python -m ...``), not as a
    # plain script path — confirm.
    unittest.main()