提交 2542f67c 编写于 作者: J jinhai

Update interface, part 1

上级 def1503c
...@@ -26,9 +26,10 @@ class Scheduler(metaclass=Singleton): ...@@ -26,9 +26,10 @@ class Scheduler(metaclass=Singleton):
if 'raw' in index_data_key: if 'raw' in index_data_key:
raw_vectors = index_data_key['raw'] raw_vectors = index_data_key['raw']
raw_vector_ids = index_data_key['raw_id']
d = index_data_key['dimension'] d = index_data_key['dimension']
index_builder = build_index.FactoryIndex() index_builder = build_index.FactoryIndex()
index = index_builder().build(d, raw_vectors) index = index_builder().build(d, raw_vectors, raw_vector_ids)
searcher = search_index.FaissSearch(index) searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(vectors, k)) result_list.append(searcher.search_by_vectors(vectors, k))
......
...@@ -23,20 +23,25 @@ class TestScheduler(unittest.TestCase): ...@@ -23,20 +23,25 @@ class TestScheduler(unittest.TestCase):
index2 = faiss.read_index(file_name) index2 = faiss.read_index(file_name)
schuduler_instance = Scheduler() scheduler_instance = Scheduler()
# query args 1 # query args 1
query_index = dict() query_index = dict()
query_index['index'] = [file_name] query_index['index'] = [file_name]
vectors = schuduler_instance.Search(query_index, vectors=xq, k=5) vectors = scheduler_instance.Search(query_index, vectors=xq, k=5)
assert np.all(vectors == Iref) assert np.all(vectors == Iref)
# query args 2 # query args 2
query_index = dict() query_index = dict()
query_index['raw'] = xt query_index['raw'] = xt
# Xiaojun TODO: 'raw_id' part
# query_index['raw_id'] =
query_index['dimension'] = d query_index['dimension'] = d
query_index['index'] = [file_name] query_index['index'] = [file_name]
vectors = schuduler_instance.Search(query_index, vectors=xq, k=5)
# Xiaojun TODO: once 'raw_id' part added, open below
# vectors = scheduler_instance.Search(query_index, vectors=xq, k=5)
# print("success") # print("success")
......
...@@ -44,29 +44,34 @@ class TestVectorEngine: ...@@ -44,29 +44,34 @@ class TestVectorEngine:
assert group_list == [{'group_name': 'test_group', 'file_number': 0}] assert group_list == [{'group_name': 'test_group', 'file_number': 0}]
# Add Vector for not exist group # Add Vector for not exist group
code = VectorEngine.AddVector('not_exist_group', self.__vector) code, vector_id = VectorEngine.AddVector('not_exist_group', self.__vector)
assert code == VectorEngine.GROUP_NOT_EXIST assert code == VectorEngine.GROUP_NOT_EXIST
assert vector_id == 'invalid'
# Add vector for exist group # Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector) code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.0'
# Add vector for exist group # Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector) code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.1'
# Add vector for exist group # Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector) code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.2'
# Add vector for exist group # Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector) code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.3'
# Check search vector interface # Check search vector interface
code, vector_id = VectorEngine.SearchVector('test_group', self.__vector, self.__limit) code, vector_id = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
assert code == VectorEngine.SUCCESS_CODE assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 0 assert vector_id == {0}
# Check create index interface # Check create index interface
code = VectorEngine.CreateIndex('test_group') code = VectorEngine.CreateIndex('test_group')
...@@ -85,8 +90,9 @@ class TestVectorEngine: ...@@ -85,8 +90,9 @@ class TestVectorEngine:
assert file_number == 0 assert file_number == 0
# Check SearchVector interface # Check SearchVector interface
code = VectorEngine.SearchVector('test_group', self.__vector, self.__limit) code, vector_ids = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
assert code == VectorEngine.GROUP_NOT_EXIST assert code == VectorEngine.GROUP_NOT_EXIST
assert vector_ids == {}
# Create Index for not exist group id # Create Index for not exist group id
code = VectorEngine.CreateIndex('test_group') code = VectorEngine.CreateIndex('test_group')
...@@ -97,17 +103,18 @@ class TestVectorEngine: ...@@ -97,17 +103,18 @@ class TestVectorEngine:
assert code == VectorEngine.SUCCESS_CODE assert code == VectorEngine.SUCCESS_CODE
def test_raw_file(self): def test_raw_file(self):
filename = VectorEngine.InsertVectorIntoRawFile('test_group', 'test_group.raw', self.__vector) filename = VectorEngine.InsertVectorIntoRawFile('test_group', 'test_group.raw', self.__vector, 'test_group.0')
assert filename == 'test_group.raw' assert filename == 'test_group.raw'
expected_list = [self.__vector] expected_list = [self.__vector]
vector_list = VectorEngine.GetVectorListFromRawFile('test_group', filename) vector_list, vector_id_list = VectorEngine.GetVectorListFromRawFile('test_group', filename)
print('expected_list: ', expected_list) print('expected_list: ', expected_list)
print('vector_list: ', vector_list) print('vector_list: ', vector_list)
expected_list = np.asarray(expected_list).astype('float32') print('vector_id_list: ', vector_id_list)
expected_list = np.asarray(expected_list).astype('float32')
assert np.all(vector_list == expected_list) assert np.all(vector_list == expected_list)
code = VectorEngine.ClearRawFile('test_group') code = VectorEngine.ClearRawFile('test_group')
......
...@@ -12,7 +12,8 @@ from engine.ingestion import serialize ...@@ -12,7 +12,8 @@ from engine.ingestion import serialize
import sys, os import sys, os
class VectorEngine(object): class VectorEngine(object):
group_dict = None group_vector_dict = None
group_vector_id_dict = None
SUCCESS_CODE = 0 SUCCESS_CODE = 0
FAULT_CODE = 1 FAULT_CODE = 1
GROUP_NOT_EXIST = 2 GROUP_NOT_EXIST = 2
...@@ -83,23 +84,25 @@ class VectorEngine(object): ...@@ -83,23 +84,25 @@ class VectorEngine(object):
print(group_id, vector) print(group_id, vector)
code, _, _ = VectorEngine.GetGroup(group_id) code, _, _ = VectorEngine.GetGroup(group_id)
if code == VectorEngine.FAULT_CODE: if code == VectorEngine.FAULT_CODE:
return VectorEngine.GROUP_NOT_EXIST return VectorEngine.GROUP_NOT_EXIST, 'invalid'
file = FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').first() file = FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').first()
group = GroupTable.query.filter(GroupTable.group_name == group_id).first() group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
if file: if file:
print('insert into exist file') print('insert into exist file')
# create vector id
vector_id = group_id + '.' + (str)(file.seq_no + 1)
# insert into raw file # insert into raw file
VectorEngine.InsertVectorIntoRawFile(group_id, file.filename, vector) VectorEngine.InsertVectorIntoRawFile(group_id, file.filename, vector, vector_id)
# check if the file can be indexed # check if the file can be indexed
if file.row_number + 1 >= ROW_LIMIT: if file.row_number + 1 >= ROW_LIMIT:
raw_data = VectorEngine.GetVectorListFromRawFile(group_id) raw_vector_array, raw_vector_id_array = VectorEngine.GetVectorListFromRawFile(group_id)
d = group.dimension d = group.dimension
# create index # create index
index_builder = build_index.FactoryIndex() index_builder = build_index.FactoryIndex()
index = index_builder().build(d, raw_data) index = index_builder().build(d, raw_vector_array, raw_vector_id_array)
# TODO(jinhai): store index into Cache # TODO(jinhai): store index into Cache
index_filename = file.filename + '_index' index_filename = file.filename + '_index'
...@@ -107,12 +110,14 @@ class VectorEngine(object): ...@@ -107,12 +110,14 @@ class VectorEngine(object):
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1, FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1,
'type': 'index', 'type': 'index',
'filename': index_filename}) 'filename': index_filename,
'seq_no': file.seq_no + 1})
pass pass
else: else:
# we still can insert into exist raw file, update database # we still can insert into exist raw file, update database
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1}) FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1,
'seq_no': file.seq_no + 1})
db.session.commit() db.session.commit()
print('Update db for raw file insertion') print('Update db for raw file insertion')
pass pass
...@@ -121,13 +126,15 @@ class VectorEngine(object): ...@@ -121,13 +126,15 @@ class VectorEngine(object):
print('add a new raw file') print('add a new raw file')
# first raw file # first raw file
raw_filename = group_id + '.raw' raw_filename = group_id + '.raw'
# create vector id
vector_id = group_id + '.' + (str)(0)
# create and insert vector into raw file # create and insert vector into raw file
VectorEngine.InsertVectorIntoRawFile(group_id, raw_filename, vector) VectorEngine.InsertVectorIntoRawFile(group_id, raw_filename, vector, vector_id)
# insert a record into database # insert a record into database
db.session.add(FileTable(group_id, raw_filename, 'raw', 1)) db.session.add(FileTable(group_id, raw_filename, 'raw', 1))
db.session.commit() db.session.commit()
return VectorEngine.SUCCESS_CODE return VectorEngine.SUCCESS_CODE, vector_id
@staticmethod @staticmethod
...@@ -135,7 +142,7 @@ class VectorEngine(object): ...@@ -135,7 +142,7 @@ class VectorEngine(object):
# Check the group exist # Check the group exist
code, _, _ = VectorEngine.GetGroup(group_id) code, _, _ = VectorEngine.GetGroup(group_id)
if code == VectorEngine.FAULT_CODE: if code == VectorEngine.FAULT_CODE:
return VectorEngine.GROUP_NOT_EXIST return VectorEngine.GROUP_NOT_EXIST, {}
group = GroupTable.query.filter(GroupTable.group_name == group_id).first() group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
...@@ -144,7 +151,7 @@ class VectorEngine(object): ...@@ -144,7 +151,7 @@ class VectorEngine(object):
index_keys = [ i.filename for i in files if i.type == 'index' ] index_keys = [ i.filename for i in files if i.type == 'index' ]
index_map = {} index_map = {}
index_map['index'] = index_keys index_map['index'] = index_keys
index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") #TODO: pass by key, get from storage index_map['raw'], index_map['raw_id'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") #TODO: pass by key, get from storage
index_map['dimension'] = group.dimension index_map['dimension'] = group.dimension
scheduler_instance = Scheduler() scheduler_instance = Scheduler()
...@@ -152,7 +159,7 @@ class VectorEngine(object): ...@@ -152,7 +159,7 @@ class VectorEngine(object):
vectors.append(vector) vectors.append(vector)
result = scheduler_instance.Search(index_map, vectors, limit) result = scheduler_instance.Search(index_map, vectors, limit)
vector_id = 0 vector_id = {0}
return VectorEngine.SUCCESS_CODE, vector_id return VectorEngine.SUCCESS_CODE, vector_id
...@@ -172,29 +179,37 @@ class VectorEngine(object): ...@@ -172,29 +179,37 @@ class VectorEngine(object):
@staticmethod @staticmethod
def InsertVectorIntoRawFile(group_id, filename, vector): def InsertVectorIntoRawFile(group_id, filename, vector, vector_id):
# print(sys._getframe().f_code.co_name, group_id, vector) # print(sys._getframe().f_code.co_name, group_id, vector)
# path = GroupHandler.GetGroupDirectory(group_id) + '/' + filename # path = GroupHandler.GetGroupDirectory(group_id) + '/' + filename
if VectorEngine.group_dict is None: if VectorEngine.group_vector_dict is None:
# print("VectorEngine.group_dict is None") # print("VectorEngine.group_vector_dict is None")
VectorEngine.group_dict = dict() VectorEngine.group_vector_dict = dict()
if VectorEngine.group_vector_id_dict is None:
VectorEngine.group_vector_id_dict = dict()
if not (group_id in VectorEngine.group_vector_dict):
VectorEngine.group_vector_dict[group_id] = []
if not (group_id in VectorEngine.group_dict): if not (group_id in VectorEngine.group_vector_id_dict):
VectorEngine.group_dict[group_id] = [] VectorEngine.group_vector_id_dict[group_id] = []
VectorEngine.group_dict[group_id].append(vector) VectorEngine.group_vector_dict[group_id].append(vector)
VectorEngine.group_vector_id_dict[group_id].append(vector_id)
print('InsertVectorIntoRawFile: ', VectorEngine.group_dict[group_id]) print('InsertVectorIntoRawFile: ', VectorEngine.group_vector_dict[group_id], VectorEngine.group_vector_id_dict[group_id])
return filename return filename
@staticmethod @staticmethod
def GetVectorListFromRawFile(group_id, filename="todo"): def GetVectorListFromRawFile(group_id, filename="todo"):
return serialize.to_array(VectorEngine.group_dict[group_id]) return serialize.to_array(VectorEngine.group_vector_dict[group_id]), serialize.to_str_array(VectorEngine.group_vector_id_dict[group_id])
@staticmethod @staticmethod
def ClearRawFile(group_id): def ClearRawFile(group_id):
print("VectorEngine.group_dict: ", VectorEngine.group_dict) print("VectorEngine.group_vector_dict: ", VectorEngine.group_vector_dict)
del VectorEngine.group_dict[group_id] del VectorEngine.group_vector_dict[group_id]
del VectorEngine.group_vector_id_dict[group_id]
return VectorEngine.SUCCESS_CODE return VectorEngine.SUCCESS_CODE
...@@ -18,8 +18,8 @@ class Vector(Resource): ...@@ -18,8 +18,8 @@ class Vector(Resource):
def post(self, group_id): def post(self, group_id):
args = self.__parser.parse_args() args = self.__parser.parse_args()
vector = args['vector'] vector = args['vector']
code = VectorEngine.AddVector(group_id, vector) code, vector_id = VectorEngine.AddVector(group_id, vector)
return jsonify({'code': code}) return jsonify({'code': code, 'vector_id': vector_id})
class VectorSearch(Resource): class VectorSearch(Resource):
......
...@@ -15,7 +15,7 @@ def FactoryIndex(index_name="DefaultIndex"): ...@@ -15,7 +15,7 @@ def FactoryIndex(index_name="DefaultIndex"):
class Index(): class Index():
def build(self, d, vectors, DEVICE=INDEX_DEVICES.CPU): def build(self, d, vectors, vector_ids, DEVICE=INDEX_DEVICES.CPU):
pass pass
@staticmethod @staticmethod
...@@ -35,7 +35,7 @@ class DefaultIndex(Index): ...@@ -35,7 +35,7 @@ class DefaultIndex(Index):
# maybe need to specif parameters # maybe need to specif parameters
pass pass
def build(self, d, vectors, DEVICE=INDEX_DEVICES.CPU): def build(self, d, vectors, vector_ids, DEVICE=INDEX_DEVICES.CPU):
index = faiss.IndexFlatL2(d) # trained index = faiss.IndexFlatL2(d) # trained
index.add(vectors) index.add(vectors)
return index return index
...@@ -47,7 +47,7 @@ class LowMemoryIndex(Index): ...@@ -47,7 +47,7 @@ class LowMemoryIndex(Index):
self.__bytes_per_vector = 8 self.__bytes_per_vector = 8
self.__bits_per_sub_vector = 8 self.__bits_per_sub_vector = 8
def build(d, vectors, DEVICE=INDEX_DEVICES.CPU): def build(d, vectors, vector_ids, DEVICE=INDEX_DEVICES.CPU):
# quantizer = faiss.IndexFlatL2(d) # quantizer = faiss.IndexFlatL2(d)
# index = faiss.IndexIVFPQ(quantizer, d, self.nlist, # index = faiss.IndexIVFPQ(quantizer, d, self.nlist,
# self.__bytes_per_vector, self.__bits_per_sub_vector) # self.__bytes_per_vector, self.__bits_per_sub_vector)
......
...@@ -9,3 +9,6 @@ def read_index(file_name): ...@@ -9,3 +9,6 @@ def read_index(file_name):
def to_array(vec): def to_array(vec):
return np.asarray(vec).astype('float32') return np.asarray(vec).astype('float32')
def to_str_array(vec):
return np.asarray(vec).astype('str')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册