提交 1bebe578 编写于 作者: X xj.lin

Merge branch 'develop' into linxj

......@@ -13,4 +13,4 @@ db = SQLAlchemy(app)
from engine.model.group_table import GroupTable
from engine.model.file_table import FileTable
from engine.controller import index_manager
from engine.controller import views
......@@ -10,8 +10,7 @@ class GroupHandler(object):
path=path.rstrip("\\")
if not os.path.exists(path):
os.makedirs(path)
print("CreateGroupDirectory, Path: ", path)
return path
@staticmethod
def DeleteGroupDirectory(group_id):
......@@ -20,7 +19,7 @@ class GroupHandler(object):
path=path.rstrip("\\")
if os.path.exists(path):
shutil.rmtree(path)
print("DeleteGroupDirectory, Path: ", path)
return path
@staticmethod
def GetGroupDirectory(group_id):
......
import pytest
from flask import Flask
from engine import app
@pytest.fixture(scope='module')
def test_client():
# Flask provides a way to test your application by exposing the Werkzeug test Client
# and handling the context locals for you.
testing_client = app.test_client()
# Establish an application context before running the tests.
ctx = app.app_context()
ctx.push()
yield testing_client # this is where the testing happens!
ctx.pop()
\ No newline at end of file
from engine.controller.group_handler import GroupHandler
from engine.settings import DATABASE_DIRECTORY
import pytest
import os
import logging
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class TestGroupHandler:
def test_get_group(self):
group_path = GroupHandler.GetGroupDirectory('test_group')
verified_path = DATABASE_DIRECTORY + '/' + 'test_group'
logger.debug(group_path)
assert group_path == verified_path
def test_create_group(self):
group_path = GroupHandler.CreateGroupDirectory('test_group')
if os.path.exists(group_path):
assert True
else:
assert False
def test_delete_group(self):
group_path = GroupHandler.GetGroupDirectory('test_group')
if os.path.exists(group_path):
assert True
GroupHandler.DeleteGroupDirectory('test_group')
if os.path.exists(group_path):
assert False
else:
assert True
else:
assert False
from engine.controller.vector_engine import VectorEngine
from engine.settings import DATABASE_DIRECTORY
from flask import jsonify
import pytest
import os
import logging
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class TestVectorEngine:
def setup_class(self):
self.__vector = [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]
self.__limit = 3
def teardown_class(self):
pass
def test_group(self):
# Make sure there is no group
code, group_id, file_number = VectorEngine.DeleteGroup('test_group')
assert code == VectorEngine.SUCCESS_CODE
assert group_id == 'test_group'
assert file_number == 0
# Add a group
code, group_id, file_number = VectorEngine.AddGroup('test_group', 8)
assert code == VectorEngine.SUCCESS_CODE
assert group_id == 'test_group'
assert file_number == 0
# Check the group existing
code, group_id, file_number = VectorEngine.GetGroup('test_group')
assert code == VectorEngine.SUCCESS_CODE
assert group_id == 'test_group'
assert file_number == 0
# Check the group list
code, group_list = VectorEngine.GetGroupList()
assert code == VectorEngine.SUCCESS_CODE
assert group_list == [{'group_name': 'test_group', 'file_number': 0}]
# Add Vector for not exist group
code = VectorEngine.AddVector('not_exist_group', self.__vector)
assert code == VectorEngine.GROUP_NOT_EXIST
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
# Check search vector interface
code, vector_id = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 0
# Check create index interface
code = VectorEngine.CreateIndex('test_group')
assert code == VectorEngine.SUCCESS_CODE
# Remove the group
code, group_id, file_number = VectorEngine.DeleteGroup('test_group')
assert code == VectorEngine.SUCCESS_CODE
assert group_id == 'test_group'
assert file_number == 0
# Check the group is disppeared
code, group_id, file_number = VectorEngine.GetGroup('test_group')
assert code == VectorEngine.FAULT_CODE
assert group_id == 'test_group'
assert file_number == 0
# Check SearchVector interface
code = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
assert code == VectorEngine.GROUP_NOT_EXIST
# Create Index for not exist group id
code = VectorEngine.CreateIndex('test_group')
assert code == VectorEngine.GROUP_NOT_EXIST
# Clear raw file
code = VectorEngine.ClearRawFile('test_group')
assert code == VectorEngine.SUCCESS_CODE
def test_raw_file(self):
filename = VectorEngine.InsertVectorIntoRawFile('test_group', 'test_group.raw', self.__vector)
assert filename == 'test_group.raw'
expected_list = [self.__vector]
vector_list = VectorEngine.GetVectorListFromRawFile('test_group', filename)
print('expected_list: ', expected_list)
print('vector_list: ', vector_list)
assert vector_list == expected_list
code = VectorEngine.ClearRawFile('test_group')
assert code == VectorEngine.SUCCESS_CODE
from engine.controller.vector_engine import VectorEngine
from engine.settings import DATABASE_DIRECTORY
from engine import app
from flask import jsonify
import pytest
import os
import logging
import json
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class TestViews:
def loads(self, resp):
return json.loads(resp.data.decode())
def test_group(self, test_client):
data = {"dimension": 10}
resp = test_client.get('/vector/group/6')
assert resp.status_code == 200
assert self.loads(resp)['code'] == 1
resp = test_client.post('/vector/group/6', data=json.dumps(data))
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
resp = test_client.get('/vector/group/6')
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
# GroupList
resp = test_client.get('/vector/group')
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
assert self.loads(resp)['group_list'] == [{'file_number': 0, 'group_name': '6'}]
resp = test_client.delete('/vector/group/6')
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
def test_vector(self, test_client):
dimension = {"dimension": 10}
resp = test_client.post('/vector/group/6', data=json.dumps(dimension))
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
vector = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]}
resp = test_client.post('/vector/add/6', data=json.dumps(vector))
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
resp = test_client.post('/vector/index/6')
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
limit = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], "limit": 3}
resp = test_client.get('/vector/search/6', data=json.dumps(limit))
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
assert self.loads(resp)['vector_id'] == 0
resp = test_client.delete('/vector/group/6')
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
......@@ -12,13 +12,17 @@ import sys, os
class VectorEngine(object):
group_dict = None
SUCCESS_CODE = 0
FAULT_CODE = 1
GROUP_NOT_EXIST = 2
@staticmethod
def AddGroup(group_id, dimension):
group = GroupTable.query.filter(GroupTable.group_name==group_id).first()
if group:
print('Already create the group: ', group_id)
return jsonify({'code': 1, 'group_name': group_id, 'file_number': group.file_number})
return VectorEngine.FAULT_CODE, group_id, group.file_number
# return jsonify({'code': 1, 'group_name': group_id, 'file_number': group.file_number})
else:
print('To create the group: ', group_id)
new_group = GroupTable(group_id, dimension)
......@@ -27,18 +31,16 @@ class VectorEngine(object):
# add into database
db.session.add(new_group)
db.session.commit()
return jsonify({'code': 0, 'group_name': group_id, 'file_number': 0})
return VectorEngine.SUCCESS_CODE, group_id, 0
@staticmethod
def GetGroup(group_id):
group = GroupTable.query.filter(GroupTable.group_name==group_id).first()
if group:
print('Found the group: ', group_id)
return jsonify({'code': 0, 'group_name': group_id, 'file_number': group.file_number})
return VectorEngine.SUCCESS_CODE, group_id, group.file_number
else:
print('Not found the group: ', group_id)
return jsonify({'code': 1, 'group_name': group_id, 'file_number': 0}) # not found
return VectorEngine.FAULT_CODE, group_id, 0
@staticmethod
......@@ -56,9 +58,9 @@ class VectorEngine(object):
db.session.delete(record)
db.session.commit()
return jsonify({'code': 0, 'group_name': group_id, 'file_number': group.file_number})
return VectorEngine.SUCCESS_CODE, group_id, group.file_number
else:
return jsonify({'code': 0, 'group_name': group_id, 'file_number': 0})
return VectorEngine.SUCCESS_CODE, group_id, 0
@staticmethod
......@@ -72,12 +74,16 @@ class VectorEngine(object):
group_list.append(group_item)
print(group_list)
return jsonify(results = group_list)
return VectorEngine.SUCCESS_CODE, group_list
@staticmethod
def AddVector(group_id, vector):
print(group_id, vector)
code, _, _ = VectorEngine.GetGroup(group_id)
if code == VectorEngine.FAULT_CODE:
return VectorEngine.GROUP_NOT_EXIST
file = FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').first()
group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
if file:
......@@ -118,37 +124,44 @@ class VectorEngine(object):
db.session.add(FileTable(group_id, raw_filename, 'raw', 1))
db.session.commit()
return jsonify({'code': 0})
return VectorEngine.SUCCESS_CODE
@staticmethod
def SearchVector(group_id, vectors, limit):
def SearchVector(group_id, vector, limit):
# Check the group exist
code, _, _ = VectorEngine.GetGroup(group_id)
if code == VectorEngine.FAULT_CODE:
return VectorEngine.GROUP_NOT_EXIST
# find all files
files = FileTable.query.filter(FileTable.group_name == group_id).all()
raw_keys = [ i.filename for i in files if i.type == 'raw' ]
index_keys = [ i.filename for i in files if i.type == 'index' ]
index_map = dict
index_map = {}
index_map['raw'] = raw_keys
index_map['index'] = index_keys # {raw:[key1, key2], index:[key3, key4]}
scheduler_instance = Scheduler
result = scheduler_instance.Search(index_map, vectors, k=limit)
scheduler_instance = Scheduler()
result = scheduler_instance.Search(index_map, vector, limit)
# according to difference files get topk of each
# reduce the topk from them
# construct response and send back
return jsonify({'code': 0})
vector_id = 0
return VectorEngine.SUCCESS_CODE, vector_id
# TODO(linxj): Debug Interface. UnSopport now
# @staticmethod
# def CreateIndex(group_id):
# # create index
# file = FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').first()
# path = GroupHandler.GetGroupDirectory(group_id) + '/' + file.filename
# print('Going to create index for: ', path)
# return jsonify({'code': 0})
@staticmethod
def CreateIndex(group_id):
# Check the group exist
code, _, _ = VectorEngine.GetGroup(group_id)
if code == VectorEngine.FAULT_CODE:
return VectorEngine.GROUP_NOT_EXIST
# create index
file = FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').first()
path = GroupHandler.GetGroupDirectory(group_id) + '/' + file.filename
print('Going to create index for: ', path)
return VectorEngine.SUCCESS_CODE
@staticmethod
......@@ -158,17 +171,13 @@ class VectorEngine(object):
if VectorEngine.group_dict is None:
# print("VectorEngine.group_dict is None")
VectorEngine.group_dict = dict()
if not (group_id in VectorEngine.group_dict):
VectorEngine.group_dict[group_id] = []
VectorEngine.group_dict[group_id].append(vector)
print('InsertVectorIntoRawFile: ', VectorEngine.group_dict[group_id])
# if filename exist
# append
# if filename not exist
# create file
# append
return filename
......@@ -176,3 +185,9 @@ class VectorEngine(object):
def GetVectorListFromRawFile(group_id, filename="todo"):
return VectorEngine.group_dict[group_id]
@staticmethod
def ClearRawFile(group_id):
print("VectorEngine.group_dict: ", VectorEngine.group_dict)
del VectorEngine.group_dict[group_id]
return VectorEngine.SUCCESS_CODE
......@@ -18,7 +18,8 @@ class Vector(Resource):
def post(self, group_id):
args = self.__parser.parse_args()
vector = args['vector']
return VectorEngine.AddVector(group_id, vector)
code = VectorEngine.AddVector(group_id, vector)
return jsonify({'code': code})
class VectorSearch(Resource):
......@@ -27,11 +28,12 @@ class VectorSearch(Resource):
self.__parser.add_argument('vector', type=float, action='append', location=['json'])
self.__parser.add_argument('limit', type=int, action='append', location=['json'])
def post(self, group_id):
def get(self, group_id):
args = self.__parser.parse_args()
print('vector: ', args['vector'])
# go to search every thing
return VectorEngine.SearchVector(group_id, args['vector'], args['limit'])
code, vector_id = VectorEngine.SearchVector(group_id, args['vector'], args['limit'])
return jsonify({'code': code, 'vector_id': vector_id})
class Index(Resource):
......@@ -40,7 +42,8 @@ class Index(Resource):
# self.__parser.add_argument('group_id', type=str)
def post(self, group_id):
return VectorEngine.CreateIndex(group_id)
code = VectorEngine.CreateIndex(group_id)
return jsonify({'code': code})
class Group(Resource):
......@@ -52,18 +55,22 @@ class Group(Resource):
def post(self, group_id):
args = self.__parser.parse_args()
dimension = args['dimension']
return VectorEngine.AddGroup(group_id, dimension)
code, group_id, file_number = VectorEngine.AddGroup(group_id, dimension)
return jsonify({'code': code, 'group': group_id, 'filenumber': file_number})
def get(self, group_id):
return VectorEngine.GetGroup(group_id)
code, group_id, file_number = VectorEngine.GetGroup(group_id)
return jsonify({'code': code, 'group': group_id, 'filenumber': file_number})
def delete(self, group_id):
return VectorEngine.DeleteGroup(group_id)
code, group_id, file_number = VectorEngine.DeleteGroup(group_id)
return jsonify({'code': code, 'group': group_id, 'filenumber': file_number})
class GroupList(Resource):
def get(self):
return VectorEngine.GetGroupList()
code, group_list = VectorEngine.GetGroupList()
return jsonify({'code': code, 'group_list': group_list})
api.add_resource(Vector, '/vector/add/<group_id>')
......
......@@ -53,45 +53,51 @@
import numpy as np
import pytest
d = 64 # dimension
nb = 100000 # database size
nq = 10000 # nb of queries
np.random.seed(1234) # make reproducible
xb = np.random.random((nb, d)).astype('float32')
xb[:, 0] += np.arange(nb) / 1000.
xc = np.random.random((nb, d)).astype('float32')
xc[:, 0] += np.arange(nb) / 1000.
xq = np.random.random((nq, d)).astype('float32')
xq[:, 0] += np.arange(nq) / 1000.
@pytest.mark.skip(reason="Not for pytest")
def basic_test():
d = 64 # dimension
nb = 100000 # database size
nq = 10000 # nb of queries
np.random.seed(1234) # make reproducible
xb = np.random.random((nb, d)).astype('float32')
xb[:, 0] += np.arange(nb) / 1000.
xc = np.random.random((nb, d)).astype('float32')
xc[:, 0] += np.arange(nb) / 1000.
xq = np.random.random((nq, d)).astype('float32')
xq[:, 0] += np.arange(nq) / 1000.
import faiss # make faiss available
index = faiss.IndexFlatL2(d) # build the index
print(index.is_trained)
index.add(xb) # add vectors to the index
print(index.ntotal)
#faiss.write_index(index, "/tmp/faiss/tempfile_1")
import faiss # make faiss available
index = faiss.IndexFlatL2(d) # build the index
print(index.is_trained)
index.add(xb) # add vectors to the index
print(index.ntotal)
#faiss.write_index(index, "/tmp/faiss/tempfile_1")
writer = faiss.VectorIOWriter()
faiss.write_index(index, writer)
ar_data = faiss.vector_to_array(writer.data)
import pickle
pickle.dump(ar_data, open("/tmp/faiss/ser_1", "wb"))
writer = faiss.VectorIOWriter()
faiss.write_index(index, writer)
ar_data = faiss.vector_to_array(writer.data)
import pickle
pickle.dump(ar_data, open("/tmp/faiss/ser_1", "wb"))
index_3 = pickle.load("/tmp/faiss/ser_1")
#index_3 = pickle.load("/tmp/faiss/ser_1")
# index_2 = faiss.IndexFlatL2(d) # build the index
# print(index_2.is_trained)
# index_2.add(xc) # add vectors to the index
# print(index_2.ntotal)
# faiss.write_index(index, "/tmp/faiss/tempfile_2")
#
# index_3 = faiss.read_index
# index_2 = faiss.IndexFlatL2(d) # build the index
# print(index_2.is_trained)
# index_2.add(xc) # add vectors to the index
# print(index_2.ntotal)
# faiss.write_index(index, "/tmp/faiss/tempfile_2")
#
# index_3 = faiss.read_index
# k = 4 # we want to see 4 nearest neighbors
# D, I = index.search(xb[:5], k) # sanity check
# print(I)
# print(D)
# D, I = index.search(xq, k) # actual search
# print(I[:5]) # neighbors of the 5 first queries
# print(I[-5:]) # neighbors of the 5 last queries
\ No newline at end of file
# k = 4 # we want to see 4 nearest neighbors
# D, I = index.search(xb[:5], k) # sanity check
# print(I)
# print(D)
# D, I = index.search(xq, k) # actual search
# print(I[:5]) # neighbors of the 5 first queries
# print(I[-5:]) # neighbors of the 5 last queries
if __name__ == '__main__':
basic_test()
from engine.controller import scheduler
# scheduler.Scheduler.Search()
\ No newline at end of file
pytest -v --disable-warnings
import os
import faiss
class StorageManager(object):
def __init__():
pass
def put(vector, directory, index_type):
pass
def take(dir):
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册