scheduler.py 2.4 KB
Newer Older
X
xj.lin 已提交
1 2
from engine.retrieval import search_index
from engine.ingestion import build_index
X
xj.lin 已提交
3
from engine.ingestion import serialize
X
xj.lin 已提交
4
import numpy as np
X
xj.lin 已提交
5

X
xj.lin 已提交
6 7 8

class Singleton(type):
    """Metaclass that gives each class using it exactly one shared instance.

    The first call to such a class constructs the instance normally and
    caches it.  Every later call returns that cached object, and the
    arguments of those later calls are ignored.  If construction raises,
    nothing is cached, so the next call tries again.

    NOTE: no locking is done here, so two threads racing on the very first
    call could each construct an instance.
    """

    # Registry shared by every class using this metaclass, keyed by class.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance


class Scheduler(metaclass=Singleton):
    """Fan a vector query out to every index described by a request dict.

    Created through the ``Singleton`` metaclass, so every ``Scheduler()``
    call returns the same object.
    """

    def search(self, index_file_key, vectors, k):
        """Search the indexes described by *index_file_key* for *vectors*.

        Args:
            index_file_key: Dict-like request read with string keys.
                ``'raw'`` (together with ``'raw_id'`` and ``'dimension'``)
                makes an index be built from those vectors and searched.
                ``'index'`` is an iterable of keys, each loaded with
                ``get_index_data`` and searched.
            vectors: Query vectors, converted with ``serialize.to_array``.
            k: Neighbour count passed to every ``search_by_vectors`` call.
                Must not be 0.

        Returns:
            If exactly one index was searched, ``result.vectors[0].tolist()``
            of that single result.  Otherwise the list of per-index results,
            which is empty when neither ``'raw'`` nor ``'index'`` is present.

        Raises:
            ValueError: If ``k`` is 0.
        """
        # Explicit check rather than ``assert`` so the guard is not stripped
        # when Python runs with -O.
        if k == 0:
            raise ValueError("k must be non-zero")

        query_vectors = serialize.to_array(vectors)
        return self.__scheduler(index_file_key, query_vectors, k)

    def __scheduler(self, index_data_key, vectors, k):
        """Query every index named in *index_data_key*; see ``search``."""
        result_list = []

        if 'raw' in index_data_key:
            index = self.__build_raw_index(index_data_key)
            result_list.append(self.__search_one(index, vectors, k))

        if 'index' in index_data_key:
            for key in index_data_key['index']:
                index = get_index_data(key)
                result_list.append(self.__search_one(index, vectors, k))

        if len(result_list) == 1:
            # TODO(linxj): fix hard code -- only ``vectors[0]`` of the single
            # result is returned, and the return type differs from the
            # multi-result branch below.
            return result_list[0].vectors[0].tolist()

        # TODO(linxj): add topk -- the per-index results are returned as-is
        # instead of being merged into one top-k answer.
        return result_list

    @staticmethod
    def __build_raw_index(index_data_key):
        """Build an index from the ``'raw'``/``'raw_id'``/``'dimension'`` entries."""
        raw_vectors = index_data_key['raw']
        raw_vector_ids = index_data_key['raw_id']
        d = index_data_key['dimension']
        index_builder = build_index.FactoryIndex()
        return index_builder().build(d, raw_vectors, raw_vector_ids)

    @staticmethod
    def __search_one(index, vectors, k):
        """Query *index* for *vectors* through ``search_index.FaissSearch``."""
        return search_index.FaissSearch(index).search_by_vectors(vectors, k)
X
xj.lin 已提交
68 69


X
xj.lin 已提交
70 71
def get_index_data(key):
    """Read the stored index identified by *key*.

    Thin wrapper over ``serialize.read_index``.  Its result is handed back
    unchanged and used by the scheduler as a searchable index.
    """
    loaded_index = serialize.read_index(key)
    return loaded_index