diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp index 5ba749d4536a9c9dafe9d93f3f62b2345c0a84ec..26ee6a7baaa7c5876dc8636c7f5c94fabdadfbd5 100644 --- a/internal/core/src/segcore/reduce_c.cpp +++ b/internal/core/src/segcore/reduce_c.cpp @@ -67,28 +67,31 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) { //} void -GetResultData(std::vector>& search_records, - std::vector& search_results, - int64_t nq, - int64_t topk) { +ReduceResultData(std::vector& search_results, int64_t nq, int64_t topk) { AssertInfo(topk > 0, "topk must greater than 0"); auto num_segments = search_results.size(); AssertInfo(num_segments > 0, "num segment must greater than 0"); + for (int i = 0; i < num_segments; i++) { + auto search_result = search_results[i]; + AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); + AssertInfo(search_result->primary_keys_.size() == nq * topk, "incorrect search result primary key size"); + AssertInfo(search_result->result_distances_.size() == nq * topk, "incorrect search result distance size"); + } + std::vector> search_records(num_segments); std::unordered_set pk_set; int64_t skip_dup_cnt = 0; + + // reduce search results for (int64_t qi = 0; qi < nq; qi++) { std::vector result_pairs; int64_t base_offset = qi * topk; - for (int j = 0; j < num_segments; ++j) { - auto search_result = search_results[j]; - AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); - AssertInfo(search_result->primary_keys_.size() == nq * topk, "incorrect search result primary key size"); - AssertInfo(search_result->result_distances_.size() == nq * topk, "incorrect search result distance size"); + for (int i = 0; i < num_segments; i++) { + auto search_result = search_results[i]; auto primary_key = search_result->primary_keys_[base_offset]; auto distance = search_result->result_distances_[base_offset]; result_pairs.push_back( - SearchResultPair(primary_key, distance, search_result, j, base_offset, base_offset + topk)); + SearchResultPair(primary_key, distance, search_result, i, base_offset, base_offset + topk)); } int64_t curr_offset = base_offset; @@ -125,18 +128,11 @@ GetResultData(std::vector>& search_records, } #endif } - if (skip_dup_cnt > 0) { - LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt; - } -} + LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt; -void -ResetSearchResult(std::vector>& search_records, std::vector& search_results) { - auto num_segments = search_results.size(); - AssertInfo(num_segments > 0, "num segment must greater than 0"); + // after reduce, remove redundant values in primary_keys, result_distances and internal_seg_offsets for (int i = 0; i < num_segments; i++) { auto search_result = search_results[i]; - AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); if (search_result->result_offsets_.size() == 0) { continue; } @@ -144,23 +140,12 @@ ResetSearchResult(std::vector>& search_records, std::vector std::vector primary_keys; std::vector result_distances; std::vector internal_seg_offsets; - int64_t primary_key; - float distance; - int64_t internal_seg_offset; for (int j = 0; j < search_records[i].size(); j++) { auto& offset = search_records[i][j]; - if (offset != INVALID_OFFSET) { - primary_key = search_result->primary_keys_[offset]; - distance = search_result->result_distances_[offset]; - internal_seg_offset = search_result->internal_seg_offsets_[offset]; - } else { - primary_key = INVALID_ID; - distance = MAXFLOAT; - internal_seg_offset = INVALID_SEG_OFFSET; - } - primary_keys.push_back(primary_key); - result_distances.push_back(distance); - internal_seg_offsets.push_back(internal_seg_offset); + primary_keys.push_back(offset != INVALID_OFFSET ? search_result->primary_keys_[offset] : INVALID_ID); + result_distances.push_back(offset != INVALID_OFFSET ? search_result->result_distances_[offset] : MAXFLOAT); + internal_seg_offsets.push_back(offset != INVALID_OFFSET ? search_result->internal_seg_offsets_[offset] + : INVALID_SEG_OFFSET); } search_result->primary_keys_ = primary_keys; @@ -179,7 +164,6 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul } auto topk = search_results[0]->topk_; auto num_queries = search_results[0]->num_queries_; - std::vector> search_records(num_segments); // get primary keys for duplicates removal for (auto& search_result : search_results) { @@ -187,8 +171,7 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul segment->FillPrimaryKeys(plan, *search_result); } - GetResultData(search_records, search_results, num_queries, topk); - ResetSearchResult(search_records, search_results); + ReduceResultData(search_results, num_queries, topk); // fill in other entities for (auto& search_result : search_results) {