未验证 提交 9cc4d0bf 编写于 作者: J Jin Hai 提交者: GitHub

Merge pull request #424 from BossZou/0.6.0

0.6.0
......@@ -9,3 +9,5 @@ output.info
output_new.info
server.info
*.pyc
src/grpc/python_gen.h
src/grpc/python/
......@@ -2,6 +2,7 @@ import logging
import threading
from functools import wraps
from milvus import Milvus
from milvus.client.hooks import BaseaSearchHook
from mishards import (settings, exceptions)
from utils import singleton
......@@ -9,6 +10,12 @@ from utils import singleton
logger = logging.getLogger(__name__)
class Searchook(BaseaSearchHook):
def on_response(self, *args, **kwargs):
return True
class Connection:
def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
self.name = name
......@@ -18,6 +25,9 @@ class Connection:
self.conn = Milvus()
self.error_handlers = [] if not error_handlers else error_handlers
self.on_retry_func = kwargs.get('on_retry_func', None)
# define search hook
self.conn._set_hook(search_in_file=Searchook())
# self._connect()
def __str__(self):
......
......@@ -29,39 +29,71 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
self.router = router
self.max_workers = max_workers
def _reduce(self, source_ids, ids, source_diss, diss, k, reverse):
if source_diss[k - 1] <= diss[0]:
return source_ids, source_diss
if diss[k - 1] <= source_diss[0]:
return ids, diss
source_diss.extend(diss)
diss_t = enumerate(source_diss)
diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k]
diss_m_out = [id_ for _, id_ in diss_m_rst]
source_ids.extend(ids)
id_m_out = [source_ids[i] for i, _ in diss_m_rst]
return id_m_out, diss_m_out
def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
status = status_pb2.Status(error_code=status_pb2.SUCCESS,
reason="Success")
if not files_n_topk_results:
return status, []
request_results = defaultdict(list)
merge_id_results = []
merge_dis_results = []
calc_time = time.time()
for files_collection in files_n_topk_results:
if isinstance(files_collection, tuple):
status, _ = files_collection
return status, []
for request_pos, each_request_results in enumerate(
files_collection.topk_query_result):
request_results[request_pos].extend(
each_request_results.query_result_arrays)
request_results[request_pos] = sorted(
request_results[request_pos],
key=lambda x: x.distance,
reverse=reverse)[:topk]
row_num = files_collection.row_num
ids = files_collection.ids
diss = files_collection.distances # distance collections
# TODO: batch_len is equal to topk, may need to compare with topk
batch_len = len(ids) // row_num
for row_index in range(row_num):
id_batch = ids[row_index * batch_len: (row_index + 1) * batch_len]
dis_batch = diss[row_index * batch_len: (row_index + 1) * batch_len]
if len(merge_id_results) < row_index:
raise ValueError("merge error")
elif len(merge_id_results) == row_index:
# TODO: may bug here
merge_id_results.append(id_batch)
merge_dis_results.append(dis_batch)
else:
merge_id_results[row_index], merge_dis_results[row_index] = \
self._reduce(merge_id_results[row_index], id_batch,
merge_dis_results[row_index], dis_batch,
batch_len,
reverse)
calc_time = time.time() - calc_time
logger.info('Merge takes {}'.format(calc_time))
results = sorted(request_results.items())
topk_query_result = []
id_mrege_list = []
dis_mrege_list = []
for result in results:
query_result = TopKQueryResult(query_result_arrays=result[1])
topk_query_result.append(query_result)
for id_results, dis_results in zip(merge_id_results, merge_dis_results):
id_mrege_list.extend(id_results)
dis_mrege_list.extend(dis_results)
return status, topk_query_result
return status, id_mrege_list, dis_mrege_list
def _do_query(self,
context,
......@@ -109,8 +141,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
file_ids=query_params['file_ids'],
query_records=vectors,
top_k=topk,
nprobe=nprobe,
lazy_=True)
nprobe=nprobe
)
end = time.time()
logger.info('search_vectors_in_files takes: {}'.format(end - start))
......@@ -241,7 +273,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
logger.info('Search {}: topk={} nprobe={}'.format(
table_name, topk, nprobe))
metadata = {'resp_class': milvus_pb2.TopKQueryResultList}
metadata = {'resp_class': milvus_pb2.TopKQueryResult}
if nprobe > self.MAX_NPROBE or nprobe <= 0:
raise exceptions.InvalidArgumentError(
......@@ -275,22 +307,24 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
query_range_array.append(
Range(query_range.start_value, query_range.end_value))
status, results = self._do_query(context,
table_name,
table_meta,
query_record_array,
topk,
nprobe,
query_range_array,
metadata=metadata)
status, id_results, dis_results = self._do_query(context,
table_name,
table_meta,
query_record_array,
topk,
nprobe,
query_range_array,
metadata=metadata)
now = time.time()
logger.info('SearchVector takes: {}'.format(now - start))
topk_result_list = milvus_pb2.TopKQueryResultList(
topk_result_list = milvus_pb2.TopKQueryResult(
status=status_pb2.Status(error_code=status.error_code,
reason=status.reason),
topk_query_result=results)
row_num=len(query_record_array),
ids=id_results,
distances=dis_results)
return topk_result_list
@mark_grpc_method
......
......@@ -14,8 +14,7 @@ py==1.8.0
pyasn1==0.4.7
pyasn1-modules==0.2.6
pylint==2.3.1
pymilvus-test==0.2.28
#pymilvus==0.2.0
pymilvus==0.2.5
pyparsing==2.4.0
pytest==4.6.3
pytest-level==0.1.1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册