提交 2f8be3d0 编写于 作者: Y yhz

finish results reduce in mishards

上级 83d9bf69
......@@ -34,13 +34,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return source_ids, source_diss
if diss[k - 1] <= source_diss[0]:
return ids, diss
diss_t = enumerate(source_diss.extend(diss))
source_diss.extend(diss)
diss_t = enumerate(source_diss)
diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k]
diss_m_out = [id_ for _, id_ in diss_m_rst]
id_t = source_ids.extend(ids)
id_m_out = [id_t[i] for i, _ in diss_m_rst]
source_ids.extend(ids)
id_m_out = [source_ids[i] for i, _ in diss_m_rst]
return id_m_out, diss_m_out
......@@ -50,8 +51,6 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
if not files_n_topk_results:
return status, []
# request_results = defaultdict(list)
# row_num = files_n_topk_results[0].row_num
merge_id_results = []
merge_dis_results = []
......@@ -64,6 +63,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
row_num = files_collection.row_num
ids = files_collection.ids
diss = files_collection.distances # distance collections
# TODO: batch_len is equal to topk
batch_len = len(ids) // row_num
for row_index in range(row_num):
......@@ -77,28 +77,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
merge_id_results.append(id_batch)
merge_dis_results.append(dis_batch)
else:
merge_id_results[row_index].extend(ids[row_index * batch_len, (row_index + 1) * batch_len])
merge_dis_results[row_index].extend(diss[row_index * batch_len, (row_index + 1) * batch_len])
# _reduce(_ids, _diss, k, reverse)
merge_id_results[row_index], merge_dis_results[row_index] = \
self._reduce(merge_id_results[row_index], id_batch,
merge_dis_results[row_index], dis_batch,
batch_len,
reverse)
# for request_pos, each_request_results in enumerate(
# files_collection.topk_query_result):
# request_results[request_pos].extend(
# each_request_results.query_result_arrays)
# request_results[request_pos] = sorted(
# request_results[request_pos],
# key=lambda x: x.distance,
# reverse=reverse)[:topk]
calc_time = time.time() - calc_time
logger.info('Merge takes {}'.format(calc_time))
# results = sorted(request_results.items())
id_mrege_list = []
dis_mrege_list = []
......@@ -106,10 +94,6 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
id_mrege_list.extend(id_results)
dis_mrege_list.extend(dis_results)
# for result in results:
# query_result = TopKQueryResult(query_result_arrays=result[1])
# topk_query_result.append(query_result)
return status, id_mrege_list, dis_mrege_list
def _do_query(self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册