提交 925af7e1 编写于 作者: X xj.lin

add schuduler unittest and fix some bug

上级 da427877
from engine.retrieval import search_index
from engine.ingestion import build_index
from engine.ingestion import serialize
class Singleton(type):
_instances = {}
......@@ -11,9 +12,9 @@ class Singleton(type):
class Scheduler(metaclass=Singleton):
def Search(self, index_file_key, vectors, k):
assert index_file_key
assert vectors
assert k
# assert index_file_key
# assert vectors
# assert k
return self.__scheduler(index_file_key, vectors, k)
......@@ -21,30 +22,29 @@ class Scheduler(metaclass=Singleton):
def __scheduler(self, index_data_key, vectors, k):
result_list = []
raw_data_list = index_data_key['raw']
index_data_list = index_data_key['index']
if 'raw' in index_data_key:
raw_vectors = index_data_key['raw']
d = index_data_key['dimension']
for key in raw_data_list:
raw_data, d = self.GetRawData(key)
if 'raw' in index_data_key:
index_builder = build_index.FactoryIndex()
index = index_builder().build(d, raw_data)
searcher = search_index.FaissSearch(index) # silly
index = index_builder().build(d, raw_vectors)
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(vectors, k))
index_data_list = index_data_key['index']
for key in index_data_list:
index = self.GetIndexData(key)
index = GetIndexData(key)
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(vectors, k))
if len(result_list) == 1:
return result_list[0].vectors
result = search_index.top_k(sum(result_list), k)
return result
# result = search_index.top_k(result_list, k)
return result_list
def GetIndexData(self, key):
pass
def GetRawData(self, key):
pass
def GetIndexData(key):
return serialize.read_index(key)
\ No newline at end of file
import unittest
from ..scheduler import *
import unittest
import faiss
import numpy as np
class TestScheduler(unittest.TestCase):
def test_schedule(self):
d = 64
nb = 10000
nq = 100
nt = 5000
xt, xb, xq = get_dataset(d, nb, nt, nq)
file_name = "/tmp/faiss/tempfile_1"
index = faiss.IndexFlatL2(d)
print(index.is_trained)
index.add(xb)
faiss.write_index(index, file_name)
Dref, Iref = index.search(xq, 5)
index2 = faiss.read_index(file_name)
schuduler_instance = Scheduler()
# query args 1
query_index = dict()
query_index['index'] = [file_name]
vectors = schuduler_instance.Search(query_index, vectors=xq, k=5)
assert np.all(vectors == Iref)
# query args 2
query_index = dict()
query_index['raw'] = xt
query_index['dimension'] = d
query_index['index'] = [file_name]
vectors = schuduler_instance.Search(query_index, vectors=xq, k=5)
# print("success")
def get_dataset(d, nb, nt, nq):
"""A dataset that is not completely random but still challenging to
index
"""
d1 = 10 # intrinsic dimension (more or less)
n = nb + nt + nq
rs = np.random.RandomState(1338)
x = rs.normal(size=(n, d1))
x = np.dot(x, rs.rand(d1, d))
# now we have a d1-dim ellipsoid in d-dimensional space
# higher factor (>4) -> higher frequency -> less linear
x = x * (rs.rand(d) * 4 + 0.1)
x = np.sin(x)
x = x.astype('float32')
return x[:nt], x[nt:-nq], x[-nq:]
if __name__ == "__main__":
unittest.main()
\ No newline at end of file
......@@ -8,6 +8,7 @@ from flask import jsonify
from engine import db
from engine.ingestion import build_index
from engine.controller.scheduler import Scheduler
from engine.ingestion import serialize
import sys, os
class VectorEngine(object):
......@@ -98,14 +99,15 @@ class VectorEngine(object):
# create index
index_builder = build_index.FactoryIndex()
index = index_builder().build(d, raw_data) # type: index
index = build_index.Index.serialize(index) # type: array
index = index_builder().build(d, raw_data)
# TODO(jinhai): store index into Cache
index_filename = file.filename + '_index'
serialize.write_index(file_name=index_filename, index=index)
# TODO(jinhai): Update raw_file_name => index_file_name
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1, 'type': 'index'})
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1,
'type': 'index',
'filename': index_filename})
pass
else:
......@@ -135,13 +137,15 @@ class VectorEngine(object):
if code == VectorEngine.FAULT_CODE:
return VectorEngine.GROUP_NOT_EXIST
group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
# find all files
files = FileTable.query.filter(FileTable.group_name == group_id).all()
raw_keys = [ i.filename for i in files if i.type == 'raw' ]
index_keys = [ i.filename for i in files if i.type == 'index' ]
index_map = {}
index_map['raw'] = raw_keys
index_map['index'] = index_keys # {raw:[key1, key2], index:[key3, key4]}
index_map['index'] = index_keys
index_map['raw'] = GetVectorListFromRawFile(group_id)
index_map['dimension'] = group.dimension
scheduler_instance = Scheduler()
result = scheduler_instance.Search(index_map, vector, limit)
......
import faiss
def write_index(index, file_name):
faiss.write_index(index, file_name)
def read_index(file_name):
return faiss.read_index(file_name)
\ No newline at end of file
......@@ -7,8 +7,9 @@ class SearchResult():
self.vectors = I
def __add__(self, other):
self.distance += other.distance
self.vectors += other.vectors
distance = self.distance + other.distance
vectors = self.vectors + other.vectors
return SearchResult(distance, vectors)
class FaissSearch():
......@@ -31,6 +32,7 @@ class FaissSearch():
D, I = self.__index.search(vector_list, k)
return SearchResult(D, I)
import heapq
def top_k(input, k):
#sorted = heapq.nsmallest(k, input, key=input.key)
pass
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册