提交 675b5317 编写于 作者: X xj.lin

Merge branch 'develop' into linxj

......@@ -27,9 +27,10 @@ class Scheduler(metaclass=Singleton):
if 'raw' in index_data_key:
raw_vectors = index_data_key['raw']
raw_vector_ids = index_data_key['raw_id']
d = index_data_key['dimension']
index_builder = build_index.FactoryIndex()
index = index_builder().build(d, raw_vectors)
index = index_builder().build(d, raw_vectors, raw_vector_ids)
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(vectors, k))
......@@ -43,7 +44,7 @@ class Scheduler(metaclass=Singleton):
if len(result_list) == 1:
return result_list[0].vectors
return result_list; # TODO(linxj): add topk
return result_list; # TODO(linxj): add topk
# d_list = np.array([])
# v_list = np.array([])
......
......@@ -12,31 +12,41 @@ class TestScheduler(unittest.TestCase):
nq = 2
nt = 5000
xt, xb, xq = get_dataset(d, nb, nt, nq)
ids_xb = np.arange(xb.shape[0])
ids_xt = np.arange(xt.shape[0])
file_name = "/tmp/tempfile_1"
index = faiss.IndexFlatL2(d)
print(index.is_trained)
index.add(xb)
faiss.write_index(index, file_name)
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(xb, ids_xb)
Dref, Iref = index.search(xq, 5)
faiss.write_index(index, file_name)
index2 = faiss.read_index(file_name)
scheduler_instance = Scheduler()
schuduler_instance = Scheduler()
# query 1
query_index = dict()
query_index['index'] = [file_name]
vectors = scheduler_instance.search(query_index, vectors=xq, k=5)
assert np.all(vectors == Iref)
# query args 1
# query_index = dict()
# query_index['index'] = [file_name]
# vectors = schuduler_instance.search(query_index, vectors=xq, k=5)
# assert np.all(vectors == Iref)
# query 2
query_index.clear()
query_index['raw'] = xb
query_index['raw_id'] = ids_xb
query_index['dimension'] = d
vectors = scheduler_instance.search(query_index, vectors=xq, k=5)
assert np.all(vectors == Iref)
# query args 2
# query_index = dict()
# query 3
# TODO(linxj): continue...
# query_index.clear()
# query_index['raw'] = xt
# query_index['raw_id'] = ids_xt
# query_index['dimension'] = d
# query_index['index'] = [file_name]
# vectors = schuduler_instance.search(query_index, vectors=xq, k=5)
# print("success")
# vectors = scheduler_instance.search(query_index, vectors=xq, k=5)
# assert np.all(vectors == Iref)
def get_dataset(d, nb, nt, nq):
......
......@@ -44,29 +44,34 @@ class TestVectorEngine:
assert group_list == [{'group_name': 'test_group', 'file_number': 0}]
# Add Vector for not exist group
code = VectorEngine.AddVector('not_exist_group', self.__vector)
code, vector_id = VectorEngine.AddVector('not_exist_group', self.__vector)
assert code == VectorEngine.GROUP_NOT_EXIST
assert vector_id == 'invalid'
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector)
code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.0'
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector)
code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.1'
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector)
code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.2'
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector)
code, vector_id = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 'test_group.3'
# Check search vector interface
code, vector_id = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 0
assert vector_id == ['test_group.0']
# Check create index interface
code = VectorEngine.CreateIndex('test_group')
......@@ -85,8 +90,9 @@ class TestVectorEngine:
assert file_number == 0
# Check SearchVector interface
code = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
code, vector_ids = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
assert code == VectorEngine.GROUP_NOT_EXIST
assert vector_ids == {}
# Create Index for not exist group id
code = VectorEngine.CreateIndex('test_group')
......@@ -97,17 +103,18 @@ class TestVectorEngine:
assert code == VectorEngine.SUCCESS_CODE
def test_raw_file(self):
filename = VectorEngine.InsertVectorIntoRawFile('test_group', 'test_group.raw', self.__vector)
filename = VectorEngine.InsertVectorIntoRawFile('test_group', 'test_group.raw', self.__vector, 0)
assert filename == 'test_group.raw'
expected_list = [self.__vector]
vector_list = VectorEngine.GetVectorListFromRawFile('test_group', filename)
vector_list, vector_id_list = VectorEngine.GetVectorListFromRawFile('test_group', filename)
print('expected_list: ', expected_list)
print('vector_list: ', vector_list)
expected_list = np.asarray(expected_list).astype('float32')
print('vector_id_list: ', vector_id_list)
expected_list = np.asarray(expected_list).astype('float32')
assert np.all(vector_list == expected_list)
code = VectorEngine.ClearRawFile('test_group')
......
......@@ -71,7 +71,7 @@ class TestViews:
resp = test_client.get('/vector/search/6', data=json.dumps(limit), headers = TestViews.HEADERS)
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
assert self.loads(resp)['vector_id'] == 0
assert self.loads(resp)['vector_id'] == ['6.0']
resp = test_client.delete('/vector/group/6', headers = TestViews.HEADERS)
assert resp.status_code == 200
......
......@@ -12,7 +12,8 @@ from engine.ingestion import serialize
import sys, os
class VectorEngine(object):
group_dict = None
group_vector_dict = None
group_vector_id_dict = None
SUCCESS_CODE = 0
FAULT_CODE = 1
GROUP_NOT_EXIST = 2
......@@ -83,23 +84,25 @@ class VectorEngine(object):
print(group_id, vector)
code, _, _ = VectorEngine.GetGroup(group_id)
if code == VectorEngine.FAULT_CODE:
return VectorEngine.GROUP_NOT_EXIST
return VectorEngine.GROUP_NOT_EXIST, 'invalid'
file = FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').first()
group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
if file:
print('insert into exist file')
# create vector id
vector_id = file.seq_no + 1
# insert into raw file
VectorEngine.InsertVectorIntoRawFile(group_id, file.filename, vector)
VectorEngine.InsertVectorIntoRawFile(group_id, file.filename, vector, vector_id)
# check if the file can be indexed
if file.row_number + 1 >= ROW_LIMIT:
raw_data = VectorEngine.GetVectorListFromRawFile(group_id)
raw_vector_array, raw_vector_id_array = VectorEngine.GetVectorListFromRawFile(group_id)
d = group.dimension
# create index
index_builder = build_index.FactoryIndex()
index = index_builder().build(d, raw_data)
index = index_builder().build(d, raw_vector_array, raw_vector_id_array)
# TODO(jinhai): store index into Cache
index_filename = file.filename + '_index'
......@@ -107,13 +110,14 @@ class VectorEngine(object):
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1,
'type': 'index',
'filename': index_filename})
'filename': index_filename,
'seq_no': file.seq_no + 1})
db.session.commit()
VectorEngine.group_dict = None
else:
# we still can insert into exist raw file, update database
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1})
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1,
'seq_no': file.seq_no + 1})
db.session.commit()
print('Update db for raw file insertion')
pass
......@@ -122,13 +126,16 @@ class VectorEngine(object):
print('add a new raw file')
# first raw file
raw_filename = group_id + '.raw'
# create vector id
vector_id = 0
# create and insert vector into raw file
VectorEngine.InsertVectorIntoRawFile(group_id, raw_filename, vector)
VectorEngine.InsertVectorIntoRawFile(group_id, raw_filename, vector, vector_id)
# insert a record into database
db.session.add(FileTable(group_id, raw_filename, 'raw', 1))
db.session.commit()
return VectorEngine.SUCCESS_CODE
vector_id_str = group_id + '.' + str(vector_id)
return VectorEngine.SUCCESS_CODE, vector_id_str
@staticmethod
......@@ -136,16 +143,15 @@ class VectorEngine(object):
# Check the group exist
code, _, _ = VectorEngine.GetGroup(group_id)
if code == VectorEngine.FAULT_CODE:
return VectorEngine.GROUP_NOT_EXIST
return VectorEngine.GROUP_NOT_EXIST, {}
group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
# find all files
files = FileTable.query.filter(FileTable.group_name == group_id).all()
index_keys = [ i.filename for i in files if i.type == 'index' ]
index_map = {}
index_map['index'] = index_keys
index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") #TODO: pass by key, get from storage
index_map['raw'], index_map['raw_id'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") #TODO: pass by key, get from storage
index_map['dimension'] = group.dimension
scheduler_instance = Scheduler()
......@@ -153,9 +159,12 @@ class VectorEngine(object):
vectors.append(vector)
result = scheduler_instance.search(index_map, vectors, limit)
# vector_id = 0
vector_id = [0]
vector_ids_str = []
for int_id in vector_id:
vector_ids_str.append(group_id + '.' + str(int_id))
return VectorEngine.SUCCESS_CODE, result
return VectorEngine.SUCCESS_CODE, vector_ids_str
@staticmethod
......@@ -173,30 +182,41 @@ class VectorEngine(object):
@staticmethod
def InsertVectorIntoRawFile(group_id, filename, vector):
def InsertVectorIntoRawFile(group_id, filename, vector, vector_id):
# print(sys._getframe().f_code.co_name, group_id, vector)
# path = GroupHandler.GetGroupDirectory(group_id) + '/' + filename
if VectorEngine.group_dict is None:
# print("VectorEngine.group_dict is None")
VectorEngine.group_dict = dict()
if VectorEngine.group_vector_dict is None:
# print("VectorEngine.group_vector_dict is None")
VectorEngine.group_vector_dict = dict()
if VectorEngine.group_vector_id_dict is None:
VectorEngine.group_vector_id_dict = dict()
if not (group_id in VectorEngine.group_vector_dict):
VectorEngine.group_vector_dict[group_id] = []
if not (group_id in VectorEngine.group_vector_id_dict):
VectorEngine.group_vector_id_dict[group_id] = []
if not (group_id in VectorEngine.group_dict):
VectorEngine.group_dict[group_id] = []
VectorEngine.group_vector_dict[group_id].append(vector)
VectorEngine.group_vector_id_dict[group_id].append(vector_id)
VectorEngine.group_dict[group_id].append(vector)
print('InsertVectorIntoRawFile: ', VectorEngine.group_vector_dict[group_id], VectorEngine.group_vector_id_dict[group_id])
print("cache size: ", len(VectorEngine.group_vector_dict[group_id]))
# print('InsertVectorIntoRawFile: ', VectorEngine.group_dict[group_id])
print("cache size: ", len(VectorEngine.group_dict[group_id]))
return filename
@staticmethod
def GetVectorListFromRawFile(group_id, filename="todo"):
return serialize.to_array(VectorEngine.group_dict[group_id])
print("GetVectorListFromRawFile, vectors: ", serialize.to_array(VectorEngine.group_vector_dict[group_id]))
print("GetVectorListFromRawFile, vector_ids: ", serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id]))
return serialize.to_array(VectorEngine.group_vector_dict[group_id]), serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id])
@staticmethod
def ClearRawFile(group_id):
print("VectorEngine.group_dict: ", VectorEngine.group_dict)
del VectorEngine.group_dict[group_id]
print("VectorEngine.group_vector_dict: ", VectorEngine.group_vector_dict)
del VectorEngine.group_vector_dict[group_id]
del VectorEngine.group_vector_id_dict[group_id]
return VectorEngine.SUCCESS_CODE
......@@ -3,6 +3,7 @@ from flask_restful import Resource, Api
from engine import app, db
from engine.model.group_table import GroupTable
from engine.controller.vector_engine import VectorEngine
import json
# app = Flask(__name__)
api = Api(app)
......@@ -19,8 +20,8 @@ class Vector(Resource):
print(request.json)
args = self.__parser.parse_args()
vector = args['vector']
code = VectorEngine.AddVector(group_id, vector)
return jsonify({'code': code})
code, vector_id = VectorEngine.AddVector(group_id, vector)
return jsonify({'code': code, 'vector_id': vector_id})
class VectorSearch(Resource):
......@@ -35,7 +36,9 @@ class VectorSearch(Resource):
print('limit: ', args['limit'])
# go to search every thing
code, vector_id = VectorEngine.SearchVector(group_id, args['vector'], args['limit'])
print('vector_id: ', vector_id)
return jsonify({'code': code, 'vector_id': vector_id})
#return jsonify(})
class Index(Resource):
......
......@@ -15,12 +15,12 @@ def FactoryIndex(index_name="DefaultIndex"):
class Index():
def build(self, d, vectors, DEVICE=INDEXDEVICES.CPU):
def build(self, d, vectors, vector_ids, DEVICE=INDEXDEVICES.CPU):
pass
@staticmethod
def increase(trained_index, vectors):
trained_index.add((vectors))
trained_index.add_with_ids(vectors. vector_ids)
@staticmethod
def serialize(index):
......@@ -35,10 +35,11 @@ class DefaultIndex(Index):
# maybe need to specif parameters
pass
def build(self, d, vectors, DEVICE=INDEXDEVICES.CPU):
index = faiss.IndexFlatL2(d) # trained
index.add(vectors)
return index
def build(self, d, vectors, vector_ids, DEVICE=INDEXDEVICES.CPU):
index = faiss.IndexFlatL2(d)
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(vectors, vector_ids)
return index2
class LowMemoryIndex(Index):
......@@ -47,7 +48,7 @@ class LowMemoryIndex(Index):
self.__bytes_per_vector = 8
self.__bits_per_sub_vector = 8
def build(self, d, vectors, DEVICE=INDEXDEVICES.CPU):
def build(d, vectors, vector_ids, DEVICE=INDEXDEVICES.CPU):
# quantizer = faiss.IndexFlatL2(d)
# index = faiss.IndexIVFPQ(quantizer, d, self.nlist,
# self.__bytes_per_vector, self.__bits_per_sub_vector)
......
......@@ -12,3 +12,7 @@ def read_index(file_name):
def to_array(vec):
return np.asarray(vec).astype('float32')
def to_int_array(vec):
return np.asarray(vec).astype('int64')
......@@ -16,32 +16,35 @@ class TestBuildIndex(unittest.TestCase):
nb = 10000
nq = 100
_, xb, xq = get_dataset(d, nb, 500, nq)
ids = np.arange(xb.shape[0])
# Expected result
index = faiss.IndexFlatL2(d)
index.add(xb)
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(xb, ids)
Dref, Iref = index.search(xq, 5)
builder = DefaultIndex()
index2 = builder.build(d, xb)
index2 = builder.build(d, xb, ids)
Dnew, Inew = index2.search(xq, 5)
assert np.all(Dnew == Dref) and np.all(Inew == Iref)
def test_increase(self):
d = 64
nb = 10000
nq = 100
nt = 500
xt, xb, xq = get_dataset(d, nb, nt, nq)
index = faiss.IndexFlatL2(d)
index.add(xb)
assert index.ntotal == nb
Index.increase(index, xt)
assert index.ntotal == nb + nt
# d = 64
# nb = 10000
# nq = 100
# nt = 500
# xt, xb, xq = get_dataset(d, nb, nt, nq)
#
# index = faiss.IndexFlatL2(d)
# index.add(xb)
#
# assert index.ntotal == nb
#
# Index.increase(index, xt)
# assert index.ntotal == nb + nt
pass
def test_serialize(self):
d = 64
......
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
aniso8601=6.0.0=py_0
asn1crypto=0.24.0=py36_0
atomicwrites=1.3.0=py36_1
attrs=19.1.0=py36_1
blas=1.0=mkl
ca-certificates=2019.1.23=0
certifi=2019.3.9=py36_0
cffi=1.12.2=py36h2e261b9_1
chardet=3.0.4=py36_1
click=7.0=py36_0
conda=4.6.8=py36_0
conda-env=2.6.0=1
cryptography=2.6.1=py36h1ba5d50_0
cuda92=1.0=0
faiss-cpu=1.5.0=py36_1
faiss-gpu=1.5.0=py36_cuda9.2_1
flask=1.0.2=py36_1
flask-restful=0.3.6=py_1
flask-sqlalchemy=2.3.2=py36_0
idna=2.8=py36_0
intel-openmp=2019.1=144
itsdangerous=1.1.0=py36_0
jinja2=2.10=py36_0
libedit=3.1.20170329=h6b74fdf_2
libffi=3.2.1=hd88cf55_4
libgcc-ng=8.2.0=hdf63c60_1
libgfortran-ng=7.3.0=hdf63c60_0
libstdcxx-ng=8.2.0=hdf63c60_1
markupsafe=1.1.1=py36h7b6447c_0
mkl=2019.1=144
mkl_fft=1.0.10=py36ha843d7b_0
mkl_random=1.0.2=py36hd81dba3_0
more-itertools=6.0.0=py36_0
ncurses=6.1=he6710b0_1
numpy=1.16.2=py36h7e9f1db_0
numpy-base=1.16.2=py36hde5b4d6_0
openssl=1.1.1b=h7b6447c_1
pip=19.0.3=py36_0
pluggy=0.9.0=py36_0
py=1.8.0=py36_0
pycosat=0.6.3=py36h14c3975_0
pycparser=2.19=py36_0
pycrypto=2.6.1=py36h14c3975_9
pymysql=0.9.3=py36_0
pyopenssl=19.0.0=py36_0
pysocks=1.6.8=py36_0
pytest=4.3.1=py36_0
python=3.6.8=h0371630_0
python-dateutil=2.8.0=py_0
pytz=2018.9=py_0
readline=7.0=h7b6447c_5
requests=2.21.0=py36_0
ruamel_yaml=0.15.46=py36h14c3975_0
setuptools=40.8.0=py36_0
six=1.12.0=py36_0
sqlalchemy=1.3.1=py36h7b6447c_0
sqlite=3.26.0=h7b6447c_0
tk=8.6.8=hbc83047_0
urllib3=1.24.1=py36_0
werkzeug=0.14.1=py36_0
wheel=0.33.1=py36_0
xz=5.2.4=h14c3975_4
yaml=0.1.7=had09818_2
zlib=1.2.11=h7b6447c_3
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册