提交 6a927405 编写于 作者: J jinhai

Merge branch 'jinhai' into 'develop'

Jinhai

See merge request jinhai/vecwise_engine!11
......@@ -27,9 +27,10 @@ class Scheduler(metaclass=Singleton):
if 'raw' in index_data_key:
raw_vectors = index_data_key['raw']
raw_vector_ids = index_data_key['raw_id']
d = index_data_key['dimension']
index_builder = build_index.FactoryIndex()
index = index_builder().build(d, raw_vectors)
index = index_builder().build(d, raw_vectors, raw_vector_ids)
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(vectors, k))
......@@ -43,7 +44,7 @@ class Scheduler(metaclass=Singleton):
if len(result_list) == 1:
return result_list[0].vectors
return result_list; # TODO(linxj): add topk
return result_list; # TODO(linxj): add topk
# d_list = np.array([])
# v_list = np.array([])
......
......@@ -12,31 +12,41 @@ class TestScheduler(unittest.TestCase):
nq = 2
nt = 5000
xt, xb, xq = get_dataset(d, nb, nt, nq)
ids_xb = np.arange(xb.shape[0])
ids_xt = np.arange(xt.shape[0])
file_name = "/tmp/tempfile_1"
index = faiss.IndexFlatL2(d)
print(index.is_trained)
index.add(xb)
faiss.write_index(index, file_name)
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(xb, ids_xb)
Dref, Iref = index.search(xq, 5)
faiss.write_index(index, file_name)
index2 = faiss.read_index(file_name)
scheduler_instance = Scheduler()
schuduler_instance = Scheduler()
# query 1
query_index = dict()
query_index['index'] = [file_name]
vectors = scheduler_instance.search(query_index, vectors=xq, k=5)
assert np.all(vectors == Iref)
# query args 1
# query_index = dict()
# query_index['index'] = [file_name]
# vectors = schuduler_instance.search(query_index, vectors=xq, k=5)
# assert np.all(vectors == Iref)
# query 2
query_index.clear()
query_index['raw'] = xb
query_index['raw_id'] = ids_xb
query_index['dimension'] = d
vectors = scheduler_instance.search(query_index, vectors=xq, k=5)
assert np.all(vectors == Iref)
# query args 2
# query_index = dict()
# query 3
# TODO(linxj): continue...
# query_index.clear()
# query_index['raw'] = xt
# query_index['raw_id'] = ids_xt
# query_index['dimension'] = d
# query_index['index'] = [file_name]
# vectors = schuduler_instance.search(query_index, vectors=xq, k=5)
# print("success")
# vectors = scheduler_instance.search(query_index, vectors=xq, k=5)
# assert np.all(vectors == Iref)
def get_dataset(d, nb, nt, nq):
......
......@@ -44,29 +44,34 @@ class TestVectorEngine:
assert group_list == [{'group_name': 'test_group', 'file_number': 0}]
# Add Vector for not exist group
code = VectorEngine.AddVector('not_exist_group', self.__vector)
code, vector_id = VectorEngine.AddVector('not_exist_group', self.__vector)
assert code == VectorEngine.GROUP_NOT_EXIST
assert vector_id == 'invalid'
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector)
code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.0'
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector)
code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.1'
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector)
code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.2'
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector)
code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.3'
# Check search vector interface
code, vector_id = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 0
assert vector_id == ['test_group.0']
# Check create index interface
code = VectorEngine.CreateIndex('test_group')
......@@ -85,8 +90,9 @@ class TestVectorEngine:
assert file_number == 0
# Check SearchVector interface
code = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
code, vector_ids = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
assert code == VectorEngine.GROUP_NOT_EXIST
assert vector_ids == {}
# Create Index for not exist group id
code = VectorEngine.CreateIndex('test_group')
......@@ -97,17 +103,18 @@ class TestVectorEngine:
assert code == VectorEngine.SUCCESS_CODE
def test_raw_file(self):
filename = VectorEngine.InsertVectorIntoRawFile('test_group', 'test_group.raw', self.__vector)
filename = VectorEngine.InsertVectorIntoRawFile('test_group', 'test_group.raw', self.__vector, 0)
assert filename == 'test_group.raw'
expected_list = [self.__vector]
vector_list = VectorEngine.GetVectorListFromRawFile('test_group', filename)
vector_list, vector_id_list = VectorEngine.GetVectorListFromRawFile('test_group', filename)
print('expected_list: ', expected_list)
print('vector_list: ', vector_list)
expected_list = np.asarray(expected_list).astype('float32')
print('vector_id_list: ', vector_id_list)
expected_list = np.asarray(expected_list).astype('float32')
assert np.all(vector_list == expected_list)
code = VectorEngine.ClearRawFile('test_group')
......
......@@ -71,7 +71,7 @@ class TestViews:
resp = test_client.get('/vector/search/6', data=json.dumps(limit), headers = TestViews.HEADERS)
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
assert self.loads(resp)['vector_id'] == 0
assert self.loads(resp)['vector_id'] == ['6.0']
resp = test_client.delete('/vector/group/6', headers = TestViews.HEADERS)
assert resp.status_code == 200
......
......@@ -12,7 +12,8 @@ from engine.ingestion import serialize
import sys, os
class VectorEngine(object):
group_dict = None
group_vector_dict = None
group_vector_id_dict = None
SUCCESS_CODE = 0
FAULT_CODE = 1
GROUP_NOT_EXIST = 2
......@@ -83,23 +84,25 @@ class VectorEngine(object):
print(group_id, vector)
code, _, _ = VectorEngine.GetGroup(group_id)
if code == VectorEngine.FAULT_CODE:
return VectorEngine.GROUP_NOT_EXIST
return VectorEngine.GROUP_NOT_EXIST, 'invalid'
file = FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').first()
group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
if file:
print('insert into exist file')
# create vector id
vector_id = file.seq_no + 1
# insert into raw file
VectorEngine.InsertVectorIntoRawFile(group_id, file.filename, vector)
VectorEngine.InsertVectorIntoRawFile(group_id, file.filename, vector, vector_id)
# check if the file can be indexed
if file.row_number + 1 >= ROW_LIMIT:
raw_data = VectorEngine.GetVectorListFromRawFile(group_id)
raw_vector_array, raw_vector_id_array = VectorEngine.GetVectorListFromRawFile(group_id)
d = group.dimension
# create index
index_builder = build_index.FactoryIndex()
index = index_builder().build(d, raw_data)
index = index_builder().build(d, raw_vector_array, raw_vector_id_array)
# TODO(jinhai): store index into Cache
index_filename = file.filename + '_index'
......@@ -107,12 +110,14 @@ class VectorEngine(object):
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1,
'type': 'index',
'filename': index_filename})
'filename': index_filename,
'seq_no': file.seq_no + 1})
pass
else:
# we still can insert into exist raw file, update database
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1})
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1,
'seq_no': file.seq_no + 1})
db.session.commit()
print('Update db for raw file insertion')
pass
......@@ -121,13 +126,16 @@ class VectorEngine(object):
print('add a new raw file')
# first raw file
raw_filename = group_id + '.raw'
# create vector id
vector_id = 0
# create and insert vector into raw file
VectorEngine.InsertVectorIntoRawFile(group_id, raw_filename, vector)
VectorEngine.InsertVectorIntoRawFile(group_id, raw_filename, vector, vector_id)
# insert a record into database
db.session.add(FileTable(group_id, raw_filename, 'raw', 1))
db.session.commit()
return VectorEngine.SUCCESS_CODE
vector_id_str = group_id + '.' + str(vector_id)
return VectorEngine.SUCCESS_CODE, vector_id_str
@staticmethod
......@@ -135,16 +143,15 @@ class VectorEngine(object):
# Check the group exist
code, _, _ = VectorEngine.GetGroup(group_id)
if code == VectorEngine.FAULT_CODE:
return VectorEngine.GROUP_NOT_EXIST
return VectorEngine.GROUP_NOT_EXIST, {}
group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
# find all files
files = FileTable.query.filter(FileTable.group_name == group_id).all()
index_keys = [ i.filename for i in files if i.type == 'index' ]
index_map = {}
index_map['index'] = index_keys
index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") #TODO: pass by key, get from storage
index_map['raw'], index_map['raw_id'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") #TODO: pass by key, get from storage
index_map['dimension'] = group.dimension
scheduler_instance = Scheduler()
......@@ -152,9 +159,12 @@ class VectorEngine(object):
vectors.append(vector)
result = scheduler_instance.search(index_map, vectors, limit)
vector_id = 0
vector_id = [0]
vector_ids_str = []
for int_id in vector_id:
vector_ids_str.append(group_id + '.' + str(int_id))
return VectorEngine.SUCCESS_CODE, vector_id
return VectorEngine.SUCCESS_CODE, vector_ids_str
@staticmethod
......@@ -172,29 +182,39 @@ class VectorEngine(object):
@staticmethod
def InsertVectorIntoRawFile(group_id, filename, vector):
def InsertVectorIntoRawFile(group_id, filename, vector, vector_id):
# print(sys._getframe().f_code.co_name, group_id, vector)
# path = GroupHandler.GetGroupDirectory(group_id) + '/' + filename
if VectorEngine.group_dict is None:
# print("VectorEngine.group_dict is None")
VectorEngine.group_dict = dict()
if VectorEngine.group_vector_dict is None:
# print("VectorEngine.group_vector_dict is None")
VectorEngine.group_vector_dict = dict()
if VectorEngine.group_vector_id_dict is None:
VectorEngine.group_vector_id_dict = dict()
if not (group_id in VectorEngine.group_vector_dict):
VectorEngine.group_vector_dict[group_id] = []
if not (group_id in VectorEngine.group_dict):
VectorEngine.group_dict[group_id] = []
if not (group_id in VectorEngine.group_vector_id_dict):
VectorEngine.group_vector_id_dict[group_id] = []
VectorEngine.group_dict[group_id].append(vector)
VectorEngine.group_vector_dict[group_id].append(vector)
VectorEngine.group_vector_id_dict[group_id].append(vector_id)
print('InsertVectorIntoRawFile: ', VectorEngine.group_dict[group_id])
print('InsertVectorIntoRawFile: ', VectorEngine.group_vector_dict[group_id], VectorEngine.group_vector_id_dict[group_id])
return filename
@staticmethod
def GetVectorListFromRawFile(group_id, filename="todo"):
return serialize.to_array(VectorEngine.group_dict[group_id])
print("GetVectorListFromRawFile, vectors: ", serialize.to_array(VectorEngine.group_vector_dict[group_id]))
print("GetVectorListFromRawFile, vector_ids: ", serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id]))
return serialize.to_array(VectorEngine.group_vector_dict[group_id]), serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id])
@staticmethod
def ClearRawFile(group_id):
print("VectorEngine.group_dict: ", VectorEngine.group_dict)
del VectorEngine.group_dict[group_id]
print("VectorEngine.group_vector_dict: ", VectorEngine.group_vector_dict)
del VectorEngine.group_vector_dict[group_id]
del VectorEngine.group_vector_id_dict[group_id]
return VectorEngine.SUCCESS_CODE
......@@ -3,6 +3,7 @@ from flask_restful import Resource, Api
from engine import app, db
from engine.model.group_table import GroupTable
from engine.controller.vector_engine import VectorEngine
import json
# app = Flask(__name__)
api = Api(app)
......@@ -18,8 +19,8 @@ class Vector(Resource):
def post(self, group_id):
args = self.__parser.parse_args()
vector = args['vector']
code = VectorEngine.AddVector(group_id, vector)
return jsonify({'code': code})
code, vector_id = VectorEngine.AddVector(group_id, vector)
return jsonify({'code': code, 'vector_id': vector_id})
class VectorSearch(Resource):
......@@ -34,7 +35,9 @@ class VectorSearch(Resource):
print('limit: ', args['limit'])
# go to search every thing
code, vector_id = VectorEngine.SearchVector(group_id, args['vector'], args['limit'])
print('vector_id: ', vector_id)
return jsonify({'code': code, 'vector_id': vector_id})
#return jsonify(})
class Index(Resource):
......
......@@ -15,12 +15,12 @@ def FactoryIndex(index_name="DefaultIndex"):
class Index():
def build(self, d, vectors, DEVICE=INDEXDEVICES.CPU):
def build(self, d, vectors, vector_ids, DEVICE=INDEXDEVICES.CPU):
pass
@staticmethod
def increase(trained_index, vectors):
trained_index.add((vectors))
trained_index.add_with_ids(vectors. vector_ids)
@staticmethod
def serialize(index):
......@@ -35,10 +35,11 @@ class DefaultIndex(Index):
# maybe need to specif parameters
pass
def build(self, d, vectors, DEVICE=INDEXDEVICES.CPU):
index = faiss.IndexFlatL2(d) # trained
index.add(vectors)
return index
def build(self, d, vectors, vector_ids, DEVICE=INDEXDEVICES.CPU):
index = faiss.IndexFlatL2(d)
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(vectors, vector_ids)
return index2
class LowMemoryIndex(Index):
......@@ -47,7 +48,7 @@ class LowMemoryIndex(Index):
self.__bytes_per_vector = 8
self.__bits_per_sub_vector = 8
def build(self, d, vectors, DEVICE=INDEXDEVICES.CPU):
def build(d, vectors, vector_ids, DEVICE=INDEXDEVICES.CPU):
# quantizer = faiss.IndexFlatL2(d)
# index = faiss.IndexIVFPQ(quantizer, d, self.nlist,
# self.__bytes_per_vector, self.__bits_per_sub_vector)
......
......@@ -12,3 +12,7 @@ def read_index(file_name):
def to_array(vec):
return np.asarray(vec).astype('float32')
def to_int_array(vec):
return np.asarray(vec).astype('int64')
......@@ -16,32 +16,35 @@ class TestBuildIndex(unittest.TestCase):
nb = 10000
nq = 100
_, xb, xq = get_dataset(d, nb, 500, nq)
ids = np.arange(xb.shape[0])
# Expected result
index = faiss.IndexFlatL2(d)
index.add(xb)
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(xb, ids)
Dref, Iref = index.search(xq, 5)
builder = DefaultIndex()
index2 = builder.build(d, xb)
index2 = builder.build(d, xb, ids)
Dnew, Inew = index2.search(xq, 5)
assert np.all(Dnew == Dref) and np.all(Inew == Iref)
def test_increase(self):
d = 64
nb = 10000
nq = 100
nt = 500
xt, xb, xq = get_dataset(d, nb, nt, nq)
index = faiss.IndexFlatL2(d)
index.add(xb)
assert index.ntotal == nb
Index.increase(index, xt)
assert index.ntotal == nb + nt
# d = 64
# nb = 10000
# nq = 100
# nt = 500
# xt, xb, xq = get_dataset(d, nb, nt, nq)
#
# index = faiss.IndexFlatL2(d)
# index.add(xb)
#
# assert index.ntotal == nb
#
# Index.increase(index, xt)
# assert index.ntotal == nb + nt
pass
def test_serialize(self):
d = 64
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册