提交 374bd2f6 编写于 作者: X xiaojun.lin

Merge branch 'linxj' into 'develop'

Linxj

See merge request jinhai/vecwise_engine!12
...@@ -42,7 +42,7 @@ class Scheduler(metaclass=Singleton): ...@@ -42,7 +42,7 @@ class Scheduler(metaclass=Singleton):
result_list.append(searcher.search_by_vectors(vectors, k)) result_list.append(searcher.search_by_vectors(vectors, k))
if len(result_list) == 1: if len(result_list) == 1:
return result_list[0].vectors return result_list[0].vectors[0].tolist() # TODO(linxj): fix hard code
return result_list; # TODO(linxj): add topk return result_list; # TODO(linxj): add topk
......
...@@ -6,10 +6,10 @@ import numpy as np ...@@ -6,10 +6,10 @@ import numpy as np
class TestScheduler(unittest.TestCase): class TestScheduler(unittest.TestCase):
def test_schedule(self): def test_single_query(self):
d = 64 d = 64
nb = 10000 nb = 10000
nq = 2 nq = 1
nt = 5000 nt = 5000
xt, xb, xq = get_dataset(d, nb, nt, nq) xt, xb, xq = get_dataset(d, nb, nt, nq)
ids_xb = np.arange(xb.shape[0]) ids_xb = np.arange(xb.shape[0])
......
...@@ -112,8 +112,8 @@ class VectorEngine(object): ...@@ -112,8 +112,8 @@ class VectorEngine(object):
'type': 'index', 'type': 'index',
'filename': index_filename, 'filename': index_filename,
'seq_no': file.seq_no + 1}) 'seq_no': file.seq_no + 1})
pass db.session.commit()
VectorEngine.group_dict = None
else: else:
# we still can insert into exist raw file, update database # we still can insert into exist raw file, update database
FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1, FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1,
...@@ -159,9 +159,8 @@ class VectorEngine(object): ...@@ -159,9 +159,8 @@ class VectorEngine(object):
vectors.append(vector) vectors.append(vector)
result = scheduler_instance.search(index_map, vectors, limit) result = scheduler_instance.search(index_map, vectors, limit)
vector_id = [0]
vector_ids_str = [] vector_ids_str = []
for int_id in vector_id: for int_id in result:
vector_ids_str.append(group_id + '.' + str(int_id)) vector_ids_str.append(group_id + '.' + str(int_id))
return VectorEngine.SUCCESS_CODE, vector_ids_str return VectorEngine.SUCCESS_CODE, vector_ids_str
...@@ -201,14 +200,16 @@ class VectorEngine(object): ...@@ -201,14 +200,16 @@ class VectorEngine(object):
VectorEngine.group_vector_dict[group_id].append(vector) VectorEngine.group_vector_dict[group_id].append(vector)
VectorEngine.group_vector_id_dict[group_id].append(vector_id) VectorEngine.group_vector_id_dict[group_id].append(vector_id)
print('InsertVectorIntoRawFile: ', VectorEngine.group_vector_dict[group_id], VectorEngine.group_vector_id_dict[group_id]) # print('InsertVectorIntoRawFile: ', VectorEngine.group_vector_dict[group_id], VectorEngine.group_vector_id_dict[group_id])
print("cache size: ", len(VectorEngine.group_vector_dict[group_id]))
return filename return filename
@staticmethod @staticmethod
def GetVectorListFromRawFile(group_id, filename="todo"): def GetVectorListFromRawFile(group_id, filename="todo"):
print("GetVectorListFromRawFile, vectors: ", serialize.to_array(VectorEngine.group_vector_dict[group_id])) # print("GetVectorListFromRawFile, vectors: ", serialize.to_array(VectorEngine.group_vector_dict[group_id]))
print("GetVectorListFromRawFile, vector_ids: ", serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id])) # print("GetVectorListFromRawFile, vector_ids: ", serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id]))
return serialize.to_array(VectorEngine.group_vector_dict[group_id]), serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id]) return serialize.to_array(VectorEngine.group_vector_dict[group_id]), serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id])
@staticmethod @staticmethod
......
...@@ -17,6 +17,7 @@ class Vector(Resource): ...@@ -17,6 +17,7 @@ class Vector(Resource):
self.__parser.add_argument('vector', type=float, action='append', location=['json']) self.__parser.add_argument('vector', type=float, action='append', location=['json'])
def post(self, group_id): def post(self, group_id):
print(request.json)
args = self.__parser.parse_args() args = self.__parser.parse_args()
vector = args['vector'] vector = args['vector']
code, vector_id = VectorEngine.AddVector(group_id, vector) code, vector_id = VectorEngine.AddVector(group_id, vector)
......
import numpy as np
import requests
import pytest
import logging
import json
# Base URL of the HTTP service (loopback address, port 5000) that the tests below send requests to.
url = "http://127.0.0.1:5000"
# TODO: LOG and Assert
class TestEngineFunction():
    """Exercise the vector HTTP API end to end: create a group, add vectors, search.

    Requests go to the module-level ``url``. Responses are only printed;
    nothing is asserted.
    """

    def test_1m_add(self):
        """Create a group, POST every base vector, then search with each query vector."""
        dim = 4
        base_count = 100
        query_count = 1
        top_k = 10
        _, base_vectors, query_vectors = get_dataset(dim, base_count, 1, query_count)

        group_name = "test_search_3"
        group_endpoint = url + "/vector/group/" + group_name
        requests.post(group_endpoint, json={"dimension": dim})

        # import dataset: one POST per base vector
        add_endpoint = url + "/vector/add/" + group_name
        for base_vector in base_vectors:
            response = requests.post(add_endpoint, json={'vector': base_vector.tolist()})
            print(response.json())

        # search dataset: one GET carrying a JSON body per query vector
        search_endpoint = url + "/vector/search/" + group_name
        for query_vector in query_vectors:
            payload = {'vector': query_vector.tolist(), 'limit': top_k}
            response = requests.get(search_endpoint, json=payload)
            print(response.json())
def get_dataset(d, nb, nt, nq):
    """Generate a reproducible synthetic float32 dataset split into three parts.

    Args:
        d: Dimension of every returned vector.
        nb: Number of base vectors.
        nt: Number of training vectors.
        nq: Number of query vectors (may be 0).

    Returns:
        A tuple ``(xt, xb, xq)`` of float32 arrays with shapes
        ``(nt, d)``, ``(nb, d)`` and ``(nq, d)``.
    """
    d1 = 10  # intrinsic dimension (more or less)
    n = nb + nt + nq
    rs = np.random.RandomState(1338)  # fixed seed, so every call yields the same data
    x = rs.normal(size=(n, d1))
    x = np.dot(x, rs.rand(d1, d))
    x = x * (rs.rand(d) * 4 + 0.1)
    x = np.sin(x)
    x = x.astype('float32')
    # Use explicit bounds instead of negative indices: with nq == 0,
    # x[-nq:] would be the whole array and x[nt:-nq] would be empty.
    base_end = nt + nb
    return x[:nt], x[nt:base_end], x[base_end:]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册