提交 5616ec74 编写于 作者: X xj.lin

fix much bug

上级 925af7e1
from engine.retrieval import search_index
from engine.ingestion import build_index
from engine.ingestion import serialize
import numpy as np
class Singleton(type):
_instances = {}
......@@ -22,25 +23,33 @@ class Scheduler(metaclass=Singleton):
def __scheduler(self, index_data_key, vectors, k):
result_list = []
d = None
raw_vectors = None
print("__scheduler: vectors: ", vectors)
query_vectors = np.asarray(vectors).astype('float32')
if 'raw' in index_data_key:
raw_vectors = index_data_key['raw']
raw_vectors = np.asarray(raw_vectors).astype('float32')
d = index_data_key['dimension']
if 'raw' in index_data_key:
index_builder = build_index.FactoryIndex()
print("d: ", d, " raw_vectors: ", raw_vectors)
index = index_builder().build(d, raw_vectors)
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(vectors, k))
result_list.append(searcher.search_by_vectors(query_vectors, k))
index_data_list = index_data_key['index']
for key in index_data_list:
index = GetIndexData(key)
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(vectors, k))
result_list.append(searcher.search_by_vectors(query_vectors, k))
if len(result_list) == 1:
return result_list[0].vectors
total_result = []
# result = search_index.top_k(result_list, k)
return result_list
......
......@@ -11,7 +11,9 @@ logger = logging.getLogger(__name__)
class TestVectorEngine:
def setup_class(self):
self.__vector = [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]
self.__limit = 3
self.__vector_2 = [1.2, 2.2, 3.3, 4.5, 5.5, 6.6, 7.8, 8.8]
self.__query_vector = [[1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8],[1.2, 2.2, 3.3, 4.5, 5.5, 6.6, 7.8, 8.8]]
self.__limit = 1
def teardown_class(self):
......@@ -39,6 +41,7 @@ class TestVectorEngine:
# Check the group list
code, group_list = VectorEngine.GetGroupList()
assert code == VectorEngine.SUCCESS_CODE
print("group_list: ", group_list)
assert group_list == [{'group_name': 'test_group', 'file_number': 0}]
# Add Vector for not exist group
......@@ -46,11 +49,23 @@ class TestVectorEngine:
assert code == VectorEngine.GROUP_NOT_EXIST
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector)
code = VectorEngine.AddVector('test_group', self.__vector_2)
assert code == VectorEngine.SUCCESS_CODE
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector_2)
assert code == VectorEngine.SUCCESS_CODE
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector_2)
assert code == VectorEngine.SUCCESS_CODE
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector_2)
assert code == VectorEngine.SUCCESS_CODE
# Check search vector interface
code, vector_id = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
code, vector_id = VectorEngine.SearchVector('test_group', self.__query_vector, self.__limit)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 0
......
......@@ -52,12 +52,22 @@ class TestViews:
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
vector = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]}
resp = test_client.post('/vector/add/6', data=json.dumps(vector))
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
vector = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]}
resp = test_client.post('/vector/add/6', data=json.dumps(vector))
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
resp = test_client.post('/vector/index/6')
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
limit = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], "limit": 3}
resp = test_client.get('/vector/search/6', data=json.dumps(limit))
limit = {"vector": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], "limit": 1}
resp = test_client.post('/vector/search/6', data=json.dumps(limit))
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0
assert self.loads(resp)['vector_id'] == 0
......
......@@ -94,7 +94,7 @@ class VectorEngine(object):
# check if the file can be indexed
if file.row_number + 1 >= ROW_LIMIT:
raw_data = GetVectorListFromRawFile(group_id)
raw_data = VectorEngine.GetVectorListFromRawFile(group_id)
d = group.dimension
# create index
......@@ -144,7 +144,7 @@ class VectorEngine(object):
index_keys = [ i.filename for i in files if i.type == 'index' ]
index_map = {}
index_map['index'] = index_keys
index_map['raw'] = GetVectorListFromRawFile(group_id)
index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename")
index_map['dimension'] = group.dimension
scheduler_instance = Scheduler()
......@@ -189,6 +189,7 @@ class VectorEngine(object):
@staticmethod
def GetVectorListFromRawFile(group_id, filename="todo"):
return VectorEngine.group_dict[group_id]
# return serialize.to_array(VectorEngine.group_dict[group_id])
@staticmethod
def ClearRawFile(group_id):
......
......@@ -28,9 +28,9 @@ class VectorSearch(Resource):
self.__parser.add_argument('vector', type=float, action='append', location=['json'])
self.__parser.add_argument('limit', type=int, action='append', location=['json'])
def get(self, group_id):
def post(self, group_id):
args = self.__parser.parse_args()
print('vector: ', args['vector'])
print('VectorSearch vector: ', args['vector'])
# go to search every thing
code, vector_id = VectorEngine.SearchVector(group_id, args['vector'], args['limit'])
return jsonify({'code': code, 'vector_id': vector_id})
......
......@@ -15,7 +15,7 @@ def FactoryIndex(index_name="DefaultIndex"):
class Index():
def build(d, vectors, DEVICE=INDEX_DEVICES.CPU):
def build(self, d, vectors, DEVICE=INDEX_DEVICES.CPU):
pass
@staticmethod
......
import faiss
import numpy as np
def write_index(index, file_name):
faiss.write_index(index, file_name)
def read_index(file_name):
return faiss.read_index(file_name)
\ No newline at end of file
return faiss.read_index(file_name)
def to_array(vec):
return np.asarray(vec).astype('float32')
\ No newline at end of file
......@@ -12,7 +12,6 @@ class GroupTable(db.Model):
self.group_name = group_name
self.dimension = dimension
self.file_number = 0
self.dimension = 0
def __repr__(self):
......
......@@ -6,4 +6,4 @@ SQLALCHEMY_TRACK_MODIFICATIONS = False
SQLALCHEMY_DATABASE_URI = "mysql+pymysql://vecwise@127.0.0.1:3306/vecdata"
ROW_LIMIT = 10000000
DATABASE_DIRECTORY = '/home/jinhai/disk0/vecwise/db'
\ No newline at end of file
DATABASE_DIRECTORY = '/tmp'
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册