未验证 提交 315aa0af 编写于 作者: C Cai Yudong 提交者: GitHub

Merge GetResultData and ResetResultData into ReduceResultData (#11396)

Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>
上级 db6ecd58
......@@ -67,28 +67,31 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
//}
void
GetResultData(std::vector<std::vector<int64_t>>& search_records,
std::vector<SearchResult*>& search_results,
int64_t nq,
int64_t topk) {
ReduceResultData(std::vector<SearchResult*>& search_results, int64_t nq, int64_t topk) {
AssertInfo(topk > 0, "topk must greater than 0");
auto num_segments = search_results.size();
AssertInfo(num_segments > 0, "num segment must greater than 0");
for (int i = 0; i < num_segments; i++) {
auto search_result = search_results[i];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
AssertInfo(search_result->primary_keys_.size() == nq * topk, "incorrect search result primary key size");
AssertInfo(search_result->result_distances_.size() == nq * topk, "incorrect search result distance size");
}
std::vector<std::vector<int64_t>> search_records(num_segments);
std::unordered_set<int64_t> pk_set;
int64_t skip_dup_cnt = 0;
// reduce search results
for (int64_t qi = 0; qi < nq; qi++) {
std::vector<SearchResultPair> result_pairs;
int64_t base_offset = qi * topk;
for (int j = 0; j < num_segments; ++j) {
auto search_result = search_results[j];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
AssertInfo(search_result->primary_keys_.size() == nq * topk, "incorrect search result primary key size");
AssertInfo(search_result->result_distances_.size() == nq * topk, "incorrect search result distance size");
for (int i = 0; i < num_segments; i++) {
auto search_result = search_results[i];
auto primary_key = search_result->primary_keys_[base_offset];
auto distance = search_result->result_distances_[base_offset];
result_pairs.push_back(
SearchResultPair(primary_key, distance, search_result, j, base_offset, base_offset + topk));
SearchResultPair(primary_key, distance, search_result, i, base_offset, base_offset + topk));
}
int64_t curr_offset = base_offset;
......@@ -125,18 +128,11 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
}
#endif
}
if (skip_dup_cnt > 0) {
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
}
}
void
ResetSearchResult(std::vector<std::vector<int64_t>>& search_records, std::vector<SearchResult*>& search_results) {
auto num_segments = search_results.size();
AssertInfo(num_segments > 0, "num segment must greater than 0");
// after reduce, remove redundant values in primary_keys, result_distances and internal_seg_offsets
for (int i = 0; i < num_segments; i++) {
auto search_result = search_results[i];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
if (search_result->result_offsets_.size() == 0) {
continue;
}
......@@ -144,23 +140,12 @@ ResetSearchResult(std::vector<std::vector<int64_t>>& search_records, std::vector
std::vector<int64_t> primary_keys;
std::vector<float> result_distances;
std::vector<int64_t> internal_seg_offsets;
int64_t primary_key;
float distance;
int64_t internal_seg_offset;
for (int j = 0; j < search_records[i].size(); j++) {
auto& offset = search_records[i][j];
if (offset != INVALID_OFFSET) {
primary_key = search_result->primary_keys_[offset];
distance = search_result->result_distances_[offset];
internal_seg_offset = search_result->internal_seg_offsets_[offset];
} else {
primary_key = INVALID_ID;
distance = MAXFLOAT;
internal_seg_offset = INVALID_SEG_OFFSET;
}
primary_keys.push_back(primary_key);
result_distances.push_back(distance);
internal_seg_offsets.push_back(internal_seg_offset);
primary_keys.push_back(offset != INVALID_OFFSET ? search_result->primary_keys_[offset] : INVALID_ID);
result_distances.push_back(offset != INVALID_OFFSET ? search_result->result_distances_[offset] : MAXFLOAT);
internal_seg_offsets.push_back(offset != INVALID_OFFSET ? search_result->internal_seg_offsets_[offset]
: INVALID_SEG_OFFSET);
}
search_result->primary_keys_ = primary_keys;
......@@ -179,7 +164,6 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
}
auto topk = search_results[0]->topk_;
auto num_queries = search_results[0]->num_queries_;
std::vector<std::vector<int64_t>> search_records(num_segments);
// get primary keys for duplicates removal
for (auto& search_result : search_results) {
......@@ -187,8 +171,7 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
segment->FillPrimaryKeys(plan, *search_result);
}
GetResultData(search_records, search_results, num_queries, topk);
ResetSearchResult(search_records, search_results);
ReduceResultData(search_results, num_queries, topk);
// fill in other entities
for (auto& search_result : search_results) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册