diff --git a/shards/mishards/service_handler.py b/shards/mishards/service_handler.py index 620f6213de50dec8419b1db730e4590bfd9a7d4f..640ae61ba805be1b1139d61d7100d890fcbcd821 100644 --- a/shards/mishards/service_handler.py +++ b/shards/mishards/service_handler.py @@ -34,13 +34,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return source_ids, source_diss if diss[k - 1] <= source_diss[0]: return ids, diss - - diss_t = enumerate(source_diss.extend(diss)) + + source_diss.extend(diss) + diss_t = enumerate(source_diss) diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k] diss_m_out = [id_ for _, id_ in diss_m_rst] - id_t = source_ids.extend(ids) - id_m_out = [id_t[i] for i, _ in diss_m_rst] + source_ids.extend(ids) + id_m_out = [source_ids[i] for i, _ in diss_m_rst] return id_m_out, diss_m_out @@ -50,8 +51,6 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): if not files_n_topk_results: return status, [] - # request_results = defaultdict(list) - # row_num = files_n_topk_results[0].row_num merge_id_results = [] merge_dis_results = [] @@ -64,6 +63,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): row_num = files_collection.row_num ids = files_collection.ids diss = files_collection.distances # distance collections + # TODO: batch_len is equal to topk batch_len = len(ids) // row_num for row_index in range(row_num): @@ -77,28 +77,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): merge_id_results.append(id_batch) merge_dis_results.append(dis_batch) else: - merge_id_results[row_index].extend(ids[row_index * batch_len, (row_index + 1) * batch_len]) - merge_dis_results[row_index].extend(diss[row_index * batch_len, (row_index + 1) * batch_len]) - # _reduce(_ids, _diss, k, reverse) merge_id_results[row_index], merge_dis_results[row_index] = \ self._reduce(merge_id_results[row_index], id_batch, merge_dis_results[row_index], dis_batch, batch_len, reverse) - # for request_pos, each_request_results in enumerate( - # files_collection.topk_query_result): - # request_results[request_pos].extend( - # each_request_results.query_result_arrays) - # request_results[request_pos] = sorted( - # request_results[request_pos], - # key=lambda x: x.distance, - # reverse=reverse)[:topk] calc_time = time.time() - calc_time logger.info('Merge takes {}'.format(calc_time)) - # results = sorted(request_results.items()) id_mrege_list = [] dis_mrege_list = [] @@ -106,10 +94,6 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): id_mrege_list.extend(id_results) dis_mrege_list.extend(dis_results) - # for result in results: - # query_result = TopKQueryResult(query_result_arrays=result[1]) - # topk_query_result.append(query_result) - return status, id_mrege_list, dis_mrege_list def _do_query(self,