提交 dbfb4f20 编写于 作者: X xj.lin

add list => array

上级 5616ec74
from engine.retrieval import search_index
from engine.ingestion import build_index
from engine.ingestion import serialize
import numpy as np
class Singleton(type):
_instances = {}
......@@ -15,36 +15,29 @@ class Scheduler(metaclass=Singleton):
def Search(self, index_file_key, vectors, k):
# assert index_file_key
# assert vectors
# assert k
assert k != 0
query_vectors = serialize.to_array(vectors)
return self.__scheduler(index_file_key, vectors, k)
return self.__scheduler(index_file_key, query_vectors, k)
def __scheduler(self, index_data_key, vectors, k):
result_list = []
d = None
raw_vectors = None
print("__scheduler: vectors: ", vectors)
query_vectors = np.asarray(vectors).astype('float32')
if 'raw' in index_data_key:
raw_vectors = index_data_key['raw']
raw_vectors = np.asarray(raw_vectors).astype('float32')
d = index_data_key['dimension']
if 'raw' in index_data_key:
index_builder = build_index.FactoryIndex()
print("d: ", d, " raw_vectors: ", raw_vectors)
index = index_builder().build(d, raw_vectors)
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(query_vectors, k))
result_list.append(searcher.search_by_vectors(vectors, k))
index_data_list = index_data_key['index']
for key in index_data_list:
index = GetIndexData(key)
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(query_vectors, k))
result_list.append(searcher.search_by_vectors(vectors, k))
if len(result_list) == 1:
return result_list[0].vectors
......
......@@ -3,6 +3,7 @@ from engine.settings import DATABASE_DIRECTORY
from flask import jsonify
import pytest
import os
import numpy as np
import logging
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
......@@ -104,10 +105,12 @@ class TestVectorEngine:
expected_list = [self.__vector]
vector_list = VectorEngine.GetVectorListFromRawFile('test_group', filename)
print('expected_list: ', expected_list)
print('vector_list: ', vector_list)
expected_list = np.asarray(expected_list).astype('float32')
assert vector_list == expected_list
assert np.all(vector_list == expected_list)
code = VectorEngine.ClearRawFile('test_group')
assert code == VectorEngine.SUCCESS_CODE
......
......@@ -144,7 +144,7 @@ class VectorEngine(object):
index_keys = [ i.filename for i in files if i.type == 'index' ]
index_map = {}
index_map['index'] = index_keys
index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename")
index_map['raw'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename") #TODO: pass by key, get from storage
index_map['dimension'] = group.dimension
scheduler_instance = Scheduler()
......@@ -188,8 +188,7 @@ class VectorEngine(object):
@staticmethod
def GetVectorListFromRawFile(group_id, filename="todo"):
return VectorEngine.group_dict[group_id]
# return serialize.to_array(VectorEngine.group_dict[group_id])
return serialize.to_array(VectorEngine.group_dict[group_id])
@staticmethod
def ClearRawFile(group_id):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册