diff --git a/pyengine/engine/controller/scheduler.py b/pyengine/engine/controller/scheduler.py index 64ede17877679fa8c4173997a9810a097c5a997b..48b708113371078ccf9d40f30c49cbde142c0db5 100644 --- a/pyengine/engine/controller/scheduler.py +++ b/pyengine/engine/controller/scheduler.py @@ -42,7 +42,7 @@ class Scheduler(metaclass=Singleton): result_list.append(searcher.search_by_vectors(vectors, k)) if len(result_list) == 1: - return result_list[0].vectors + return result_list[0].vectors[0].tolist() # TODO(linxj): fix hard code return result_list; # TODO(linxj): add topk diff --git a/pyengine/engine/controller/tests/test_scheduler.py b/pyengine/engine/controller/tests/test_scheduler.py index 1f69dbd94c0eef4d02034e135e0252088fc6e1bb..d3bd2a4e95a1e4fdc1b038da6b21d253e25de7cc 100644 --- a/pyengine/engine/controller/tests/test_scheduler.py +++ b/pyengine/engine/controller/tests/test_scheduler.py @@ -6,10 +6,10 @@ import numpy as np class TestScheduler(unittest.TestCase): - def test_schedule(self): + def test_single_query(self): d = 64 nb = 10000 - nq = 2 + nq = 1 nt = 5000 xt, xb, xq = get_dataset(d, nb, nt, nq) ids_xb = np.arange(xb.shape[0]) diff --git a/pyengine/engine/controller/vector_engine.py b/pyengine/engine/controller/vector_engine.py index 879a6915b9dbbc552ed9793f647c4c7c67fe6af5..ed434d3322291d5cf4c5eaf6a0dde004862372a0 100644 --- a/pyengine/engine/controller/vector_engine.py +++ b/pyengine/engine/controller/vector_engine.py @@ -159,9 +159,8 @@ class VectorEngine(object): vectors.append(vector) result = scheduler_instance.search(index_map, vectors, limit) - vector_id = [0] vector_ids_str = [] - for int_id in vector_id: + for int_id in result: vector_ids_str.append(group_id + '.' + str(int_id)) return VectorEngine.SUCCESS_CODE, vector_ids_str @@ -201,7 +200,7 @@ class VectorEngine(object): VectorEngine.group_vector_dict[group_id].append(vector) VectorEngine.group_vector_id_dict[group_id].append(vector_id) - print('InsertVectorIntoRawFile: ', VectorEngine.group_vector_dict[group_id], VectorEngine.group_vector_id_dict[group_id]) + # print('InsertVectorIntoRawFile: ', VectorEngine.group_vector_dict[group_id], VectorEngine.group_vector_id_dict[group_id]) print("cache size: ", len(VectorEngine.group_vector_dict[group_id])) return filename @@ -209,8 +208,8 @@ class VectorEngine(object): @staticmethod def GetVectorListFromRawFile(group_id, filename="todo"): - print("GetVectorListFromRawFile, vectors: ", serialize.to_array(VectorEngine.group_vector_dict[group_id])) - print("GetVectorListFromRawFile, vector_ids: ", serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id])) + # print("GetVectorListFromRawFile, vectors: ", serialize.to_array(VectorEngine.group_vector_dict[group_id])) + # print("GetVectorListFromRawFile, vector_ids: ", serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id])) return serialize.to_array(VectorEngine.group_vector_dict[group_id]), serialize.to_int_array(VectorEngine.group_vector_id_dict[group_id]) @staticmethod diff --git a/pyengine/tests/test_function.py b/pyengine/tests/test_function.py index af90871750bf5bed4234b3b35e0492e506ca54fc..bef005722f05f6b20b7c112d9e8d55449c9bbcc9 100644 --- a/pyengine/tests/test_function.py +++ b/pyengine/tests/test_function.py @@ -1,6 +1,6 @@ -import unittest import numpy as np import requests +import pytest import logging import json @@ -8,34 +8,38 @@ url = "http://127.0.0.1:5000" # TODO: LOG and Assert -class TestEngineFunction(unittest.TestCase): +class TestEngineFunction(): def test_1m_add(self): d = 4 - nb = 120 + nb = 100 nq = 1 k = 10 _, xb, xq = get_dataset(d, nb, 1, nq) - groupid = "5m" + groupid = "test_search_3" - # route_group = url + "/vector/group/" + groupid - # r = requests.post(route_group, json={"dimension": d}) - # - # # import dataset - # vector_add_route = url + "/vector/add/" + groupid - # for i in xb: - # data = dict() - # data['vector'] = i.tolist() - # r = requests.post(vector_add_route, json=data) + route_group = url + "/vector/group/" + groupid + r = requests.post(route_group, json={"dimension": d}) + + # import dataset + vector_add_route = url + "/vector/add/" + groupid + for i in xb: + data = dict() + data['vector'] = i.tolist() + # print(data) + r = requests.post(vector_add_route, json=data) + print(r.json()) # search dataset vector_search_route = url + "/vector/search/" + groupid data = dict() - data['vector'] = xq.tolist() - data['limit'] = k - r = requests.get(vector_search_route, json=data) + for i in xq: + data['vector'] = i.tolist() + data['limit'] = k + # print(data) + r = requests.get(vector_search_route, json=data) + print(r.json()) - print("finish") def get_dataset(d, nb, nt, nq): d1 = 10 # intrinsic dimension (more or less) @@ -47,7 +51,3 @@ def get_dataset(d, nb, nt, nq): x = np.sin(x) x = x.astype('float32') return x[:nt], x[nt:-nq], x[-nq:] - - -if __name__ == "__main__": - unittest.main()