未验证 提交 765907ab 编写于 作者: C Cai Yudong 提交者: GitHub

Optimize segcore Reduce (#18902)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
上级 495b214d
......@@ -28,127 +28,40 @@ ReduceHelper::Initialize() {
AssertInfo(slice_nqs_.size() > 0, "empty slice_nqs");
AssertInfo(slice_nqs_.size() == slice_topKs_.size(), "unaligned slice_nqs and slice_topKs");
unify_topK_ = search_results_[0]->unity_topK_;
total_nq_ = search_results_[0]->total_nq_;
num_segments_ = search_results_.size();
num_slices_ = slice_nqs_.size();
// prefix sum, get slices offsets
AssertInfo(num_slices_ > 0, "empty slice_nqs is not allowed");
auto slice_offsets_size = num_slices_ + 1;
nq_slice_offsets_ = std::vector<int32_t>(slice_offsets_size);
for (int i = 1; i < slice_offsets_size; i++) {
nq_slice_offsets_[i] = nq_slice_offsets_[i - 1] + slice_nqs_[i - 1];
for (auto j = nq_slice_offsets_[i - 1]; j < nq_slice_offsets_[i]; j++) {
}
}
AssertInfo(nq_slice_offsets_[num_slices_] == total_nq_,
"illegal req sizes"
", nq_slice_offsets[last] = " +
std::to_string(nq_slice_offsets_[num_slices_]) + ", total_nq = " + std::to_string(total_nq_));
slice_nqs_prefix_sum_.resize(num_slices_ + 1);
std::partial_sum(slice_nqs_.begin(), slice_nqs_.end(), slice_nqs_prefix_sum_.begin() + 1);
AssertInfo(slice_nqs_prefix_sum_[num_slices_] == total_nq_, "illegal req sizes, slice_nqs_prefix_sum_[last] = " +
std::to_string(slice_nqs_prefix_sum_[num_slices_]) +
", total_nq = " + std::to_string(total_nq_));
// init final_search_records and final_read_topKs
final_search_records_ = std::vector<std::vector<int64_t>>(num_segments_);
final_real_topKs_ = std::vector<std::vector<int64_t>>(num_segments_);
for (auto& topKs : final_real_topKs_) {
// `topKs` records real topK of each query
topKs.resize(total_nq_);
final_search_records_.resize(num_segments_);
for (auto& search_record : final_search_records_) {
search_record.resize(total_nq_);
}
}
// NOTE(review): this file is a rendered commit diff, so this span interleaves
// the REMOVED (pre-commit) inline body of Reduce() with the ADDED
// (post-commit) body; the merged text is not compilable as-is.
//
// Reduce(): top-level entry point that merges the per-segment search results
// into deduplicated, per-slice results.
void
ReduceHelper::Reduce() {
// [removed lines] old inline implementation follows.
std::vector<SearchResult*> valid_search_results;
// get primary keys for duplicates removal
for (auto search_result : search_results_) {
FilterInvalidSearchResult(search_result);
// keep only segments that actually produced results
if (search_result->get_total_result_count() > 0) {
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
segment->FillPrimaryKeys(plan_, *search_result);
valid_search_results.emplace_back(search_result);
}
}
search_results_ = valid_search_results;
num_segments_ = search_results_.size();
if (valid_search_results.size() == 0) {
// TODO: return empty search result?
return;
}
for (int i = 0; i < num_slices_; i++) {
// ReduceResultData for each slice
ReduceResultData(i);
}
// after reduce, remove invalid primary_keys, distances and ids by `final_search_records`
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
if (search_result->result_offsets_.size() != 0) {
std::vector<milvus::PkType> primary_keys;
std::vector<float> distances;
std::vector<int64_t> seg_offsets;
// compact the vectors down to the offsets chosen during reduce
for (int j = 0; j < final_search_records_[i].size(); j++) {
auto& offset = final_search_records_[i][j];
primary_keys.push_back(search_result->primary_keys_[offset]);
distances.push_back(search_result->distances_[offset]);
seg_offsets.push_back(search_result->seg_offsets_[offset]);
}
search_result->primary_keys_ = std::move(primary_keys);
search_result->distances_ = std::move(distances);
search_result->seg_offsets_ = std::move(seg_offsets);
}
// rebuild per-query topk prefix sums from the real (post-reduce) topks
search_result->topk_per_nq_prefix_sum_.resize(final_real_topKs_[i].size() + 1);
std::partial_sum(final_real_topKs_[i].begin(), final_real_topKs_[i].end(),
search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
// fill target entry
for (auto& search_result : search_results_) {
auto segment = static_cast<milvus::segcore::SegmentInterface*>(search_result->segment_);
segment->FillTargetEntry(plan_, *search_result);
}
// [added lines] the new implementation replaces the inline body above with
// four focused helpers defined later in this diff.
FillPrimaryKey();
ReduceResultData();
RefreshSearchResult();
FillEntryData();
}
// Marshal(): serialize the reduced results into one blob per result slice.
// NOTE(review): diff-rendered span — the two GetSearchResultDataSlice calls
// near the bottom are the removed/added forms of the same line; only one
// exists in either version of the file.
void
ReduceHelper::Marshal() {
// example:
// ----------------------------------
// nq0 nq1 nq2
// sr0 topk00 topk01 topk02
// sr1 topk10 topk11 topk12
// ----------------------------------
// then:
// result_slice_offsets[] = {
// 0,
// == sr0->topk_per_nq_prefix_sum_[0] + sr1->topk_per_nq_prefix_sum_[0]
// ((topk00) + (topk10)),
// == sr0->topk_per_nq_prefix_sum_[1] + sr1->topk_per_nq_prefix_sum_[1]
// ((topk00 + topk01) + (topk10 + topk11)),
// == sr0->topk_per_nq_prefix_sum_[2] + sr1->topk_per_nq_prefix_sum_[2]
// ((topk00 + topk01 + topk02) + (topk10 + topk11 + topk12)),
// == sr0->topk_per_nq_prefix_sum_[3] + sr1->topk_per_nq_prefix_sum_[3]
// }
// per-slice result counts, accumulated over all segments
auto result_slice_offsets = std::vector<int64_t>(nq_slice_offsets_.size(), 0);
for (auto search_result : search_results_) {
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() == search_result->total_nq_ + 1,
"incorrect topk_per_nq_prefix_sum_ size in search result");
for (int i = 1; i < nq_slice_offsets_.size(); i++) {
result_slice_offsets[i] += search_result->topk_per_nq_prefix_sum_[nq_slice_offsets_[i]];
}
}
AssertInfo(result_slice_offsets[num_slices_] <= total_nq_ * unify_topK_,
"illegal result_slice_offsets when Marshal, result_slice_offsets[last] = " +
std::to_string(result_slice_offsets[num_slices_]) + ", total_nq = " + std::to_string(total_nq_) +
", unify_topK = " + std::to_string(unify_topK_));
// get search result data blobs of slices
search_result_data_blobs_ = std::make_unique<milvus::segcore::SearchResultDataBlobs>();
search_result_data_blobs_->blobs.resize(num_slices_);
//#pragma omp parallel for
for (int i = 0; i < num_slices_; i++) {
// [removed] old API took a precomputed result_count:
auto result_count = result_slice_offsets[i + 1] - result_slice_offsets[i];
auto proto = GetSearchResultDataSlice(i, result_count);
// [added] new API computes the count internally:
auto proto = GetSearchResultDataSlice(i);
search_result_data_blobs_->blobs[i] = proto;
}
}
......@@ -178,102 +91,152 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
}
}
search_result->distances_ = std::move(distances);
search_result->seg_offsets_ = std::move(seg_offsets);
search_result->distances_.swap(distances);
search_result->seg_offsets_.swap(seg_offsets);
search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
void
// [removed] old signature of the per-slice reduce replaced by this commit:
ReduceHelper::ReduceResultData(int slice_index) {
// [added] FillPrimaryKey(): drop segments that produced no results and ask
// the remaining segments to fill primary keys (needed for deduplication).
ReduceHelper::FillPrimaryKey() {
std::vector<SearchResult*> valid_search_results;
// get primary keys for duplicates removal
for (auto search_result : search_results_) {
FilterInvalidSearchResult(search_result);
if (search_result->get_total_result_count() > 0) {
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
segment->FillPrimaryKeys(plan_, *search_result);
valid_search_results.emplace_back(search_result);
}
}
search_results_.swap(valid_search_results);
num_segments_ = search_results_.size();
}
void
ReduceHelper::RefreshSearchResult() {
    // Rewrite each segment's result vectors (primary keys, distances, segment
    // offsets) so they keep only the entries selected during reduce, grouped
    // per query, and rebuild topk_per_nq_prefix_sum_ from the real
    // (post-reduce) topk of every query.
    for (int i = 0; i < num_segments_; i++) {
        auto search_result = search_results_[i];
        // validate the pointer BEFORE dereferencing it (the original asserted
        // non-null only after get_total_result_count() had been called on it)
        AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
        auto result_count = search_result->get_total_result_count();
        AssertInfo(search_result->primary_keys_.size() == result_count, "incorrect search result primary key size");
        AssertInfo(search_result->distances_.size() == result_count, "incorrect search result distance size");
        // real_topks[j] counts how many entries query j keeps after reduce
        std::vector<int64_t> real_topks(total_nq_, 0);
        if (search_result->result_offsets_.size() != 0) {
            std::vector<milvus::PkType> primary_keys;
            std::vector<float> distances;
            std::vector<int64_t> seg_offsets;
            for (int j = 0; j < total_nq_; j++) {
                // final_search_records_[i][j] holds the offsets selected for
                // segment i, query j during reduce
                for (auto offset : final_search_records_[i][j]) {
                    primary_keys.push_back(search_result->primary_keys_[offset]);
                    distances.push_back(search_result->distances_[offset]);
                    seg_offsets.push_back(search_result->seg_offsets_[offset]);
                    real_topks[j]++;
                }
            }
            search_result->primary_keys_ = std::move(primary_keys);
            search_result->distances_ = std::move(distances);
            search_result->seg_offsets_ = std::move(seg_offsets);
        }
        // prefix-sum the real topks; assumes topk_per_nq_prefix_sum_ was
        // already sized to total_nq_ + 1 (done in FilterInvalidSearchResult)
        std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
    }
}
auto nq_offset_begin = nq_slice_offsets_[slice_index];
auto nq_offset_end = nq_slice_offsets_[slice_index + 1];
AssertInfo(nq_offset_begin < nq_offset_end,
"illegal nq offsets when ReduceResultData, nq_offset_begin = " + std::to_string(nq_offset_begin) +
", nq_offset_end = " + std::to_string(nq_offset_end));
// `search_records` records the search result offsets
std::vector<std::vector<int64_t>> search_records(num_segments_);
int64_t skip_dup_cnt = 0;
void
ReduceHelper::FillEntryData() {
for (auto search_result : search_results_) {
auto segment = static_cast<milvus::segcore::SegmentInterface*>(search_result->segment_);
segment->FillTargetEntry(plan_, *search_result);
}
}
// reduce search results
int64_t result_offset = 0;
for (int64_t qi = nq_offset_begin; qi < nq_offset_end; qi++) {
// ReduceSearchResultForOneNQ(): merge the candidates of query `qi` across all
// segments, keeping at most `topk` unique primary keys; `offset` is the
// running global result offset, advanced once per emitted entry. Returns the
// number of duplicates skipped.
// NOTE(review): diff-rendered span — removed (old) and added (new) lines are
// interleaved below; duplicated pairs are marked.
int64_t
ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offset) {
std::vector<SearchResultPair> result_pairs;
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
// [removed] old empty-range check:
if (search_result->topk_per_nq_prefix_sum_[qi + 1] - search_result->topk_per_nq_prefix_sum_[qi] == 0) {
// [added] equivalent check via named begin/end offsets:
auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];
auto offset_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
if (offset_beg == offset_end) {
continue;
}
// [removed] old candidate construction:
auto base_offset = search_result->topk_per_nq_prefix_sum_[qi];
auto primary_key = search_result->primary_keys_[base_offset];
auto distance = search_result->distances_[base_offset];
result_pairs.emplace_back(primary_key, distance, search_result, i, base_offset,
search_result->topk_per_nq_prefix_sum_[qi + 1]);
// [added] new candidate construction reusing offset_beg/offset_end:
auto primary_key = search_result->primary_keys_[offset_beg];
auto distance = search_result->distances_[offset_beg];
result_pairs.emplace_back(primary_key, distance, search_result, i, offset_beg, offset_end);
}
// nq has no results for all segments
if (result_pairs.size() == 0) {
continue;
return 0;
}
int64_t dup_cnt = 0;
std::unordered_set<milvus::PkType> pk_set;
// [removed] old loop-bound bookkeeping (per-slice):
int64_t last_nq_result_offset = result_offset;
while (result_offset - last_nq_result_offset < slice_topKs_[slice_index]) {
// [added] new loop: stop once `topk` results were emitted for this query
int64_t prev_offset = offset;
while (offset - prev_offset < topk) {
// pick the best remaining candidate across all segments
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
auto& pilot = result_pairs[0];
auto index = pilot.segment_index_;
// [removed]/[added] pair: `curr_pk` renamed to `pk`
auto curr_pk = pilot.primary_key_;
auto pk = pilot.primary_key_;
// no valid search result for this nq, break to next
if (curr_pk == INVALID_PK) {
if (pk == INVALID_PK) {
break;
}
// remove duplicates
if (pk_set.count(curr_pk) == 0) {
pilot.search_result_->result_offsets_.push_back(result_offset++);
search_records[index].push_back(pilot.offset_);
pk_set.insert(curr_pk);
final_real_topKs_[index][qi]++;
if (pk_set.count(pk) == 0) {
pilot.search_result_->result_offsets_.push_back(offset++);
final_search_records_[index][qi].push_back(pilot.offset_);
pk_set.insert(pk);
} else {
// skip entity with same primary key
skip_dup_cnt++;
dup_cnt++;
}
// advance the pilot to its segment's next candidate
pilot.reset();
}
}
return dup_cnt;
}
if (skip_dup_cnt > 0) {
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
// ReduceResultData(): validate each segment's result vector sizes, then
// reduce every slice query-by-query via ReduceSearchResultForOneNQ().
// NOTE(review): diff-rendered span — an unclosed fragment of the removed
// implementation is interleaved below and marked.
void
ReduceHelper::ReduceResultData() {
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
auto result_count = search_result->get_total_result_count();
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
AssertInfo(search_result->distances_.size() == result_count, "incorrect search result distance size");
AssertInfo(search_result->seg_offsets_.size() == result_count, "incorrect search result seg offset size");
AssertInfo(search_result->primary_keys_.size() == result_count, "incorrect search result primary key size");
}
// [removed] leftover fragment of the old implementation:
// append search_records to final_search_records
for (int i = 0; i < num_segments_; i++) {
for (int j = 0; j < search_records[i].size(); j++) {
final_search_records_[i].emplace_back(search_records[i][j]);
// [added] new implementation: iterate slices, then queries within a slice
int64_t skip_dup_cnt = 0;
for (int64_t slice_index = 0; slice_index < num_slices_; slice_index++) {
auto nq_begin = slice_nqs_prefix_sum_[slice_index];
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
// reduce search results
int64_t result_offset = 0;
for (int64_t qi = nq_begin; qi < nq_end; qi++) {
skip_dup_cnt += ReduceSearchResultForOneNQ(qi, slice_topKs_[slice_index], result_offset);
}
}
if (skip_dup_cnt > 0) {
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
}
}
std::vector<char>
ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
auto nq_offset_begin = nq_slice_offsets_[slice_index_];
auto nq_offset_end = nq_slice_offsets_[slice_index_ + 1];
AssertInfo(nq_offset_begin <= nq_offset_end,
"illegal offsets when GetSearchResultDataSlice, nq_offset_begin = " + std::to_string(nq_offset_begin) +
", nq_offset_end = " + std::to_string(nq_offset_end));
ReduceHelper::GetSearchResultDataSlice(int slice_index) {
auto nq_begin = slice_nqs_prefix_sum_[slice_index];
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
int64_t result_count = 0;
for (auto search_result : search_results_) {
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() == search_result->total_nq_ + 1,
"incorrect topk_per_nq_prefix_sum_ size in search result");
result_count +=
search_result->topk_per_nq_prefix_sum_[nq_end] - search_result->topk_per_nq_prefix_sum_[nq_begin];
}
auto search_result_data = std::make_unique<milvus::proto::schema::SearchResultData>();
// set unify_topK and total_nq
search_result_data->set_top_k(slice_topKs_[slice_index_]);
search_result_data->set_num_queries(nq_offset_end - nq_offset_begin);
search_result_data->mutable_topks()->Resize(nq_offset_end - nq_offset_begin, 0);
search_result_data->set_top_k(slice_topKs_[slice_index]);
search_result_data->set_num_queries(nq_end - nq_begin);
search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0);
// `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);
......@@ -306,19 +269,20 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
search_result_data->mutable_scores()->Resize(result_count, 0);
// fill pks and distances
for (auto nq_offset = nq_offset_begin; nq_offset < nq_offset_end; nq_offset++) {
int64_t topK_count = 0;
for (int i = 0; i < search_results_.size(); i++) {
auto search_result = search_results_[i];
for (auto qi = nq_begin; qi < nq_end; qi++) {
int64_t topk_count = 0;
for (auto search_result : search_results_) {
AssertInfo(search_result != nullptr, "null search result when reorganize");
if (search_result->result_offsets_.size() == 0) {
continue;
}
auto result_start = search_result->topk_per_nq_prefix_sum_[nq_offset];
auto result_end = search_result->topk_per_nq_prefix_sum_[nq_offset + 1];
for (auto offset = result_start; offset < result_end; offset++) {
auto loc = search_result->result_offsets_[offset];
auto topk_start = search_result->topk_per_nq_prefix_sum_[qi];
auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
topk_count += topk_end - topk_start;
for (auto ki = topk_start; ki < topk_end; ki++) {
auto loc = search_result->result_offsets_[ki];
AssertInfo(loc < result_count && loc >= 0,
"invalid loc when GetSearchResultDataSlice, loc = " + std::to_string(loc) +
", result_count = " + std::to_string(result_count));
......@@ -326,12 +290,12 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
switch (pk_type) {
case milvus::DataType::INT64: {
search_result_data->mutable_ids()->mutable_int_id()->mutable_data()->Set(
loc, std::visit(Int64PKVisitor{}, search_result->primary_keys_[offset]));
loc, std::visit(Int64PKVisitor{}, search_result->primary_keys_[ki]));
break;
}
case milvus::DataType::VARCHAR: {
*search_result_data->mutable_ids()->mutable_str_id()->mutable_data()->Mutable(loc) =
std::visit(StrPKVisitor{}, search_result->primary_keys_[offset]);
std::visit(StrPKVisitor{}, search_result->primary_keys_[ki]);
break;
}
default: {
......@@ -340,17 +304,14 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
}
// set result distances
search_result_data->mutable_scores()->Set(loc, search_result->distances_[offset]);
search_result_data->mutable_scores()->Set(loc, search_result->distances_[ki]);
// set result offset to fill output fields data
result_pairs[loc] = std::make_pair(search_result, offset);
result_pairs[loc] = std::make_pair(search_result, ki);
}
topK_count += search_result->topk_per_nq_prefix_sum_[nq_offset + 1] -
search_result->topk_per_nq_prefix_sum_[nq_offset];
}
// update result topKs
search_result_data->mutable_topks()->Set(nq_offset - nq_offset_begin, topK_count);
search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
}
AssertInfo(search_result_data->scores_size() == result_count,
......
......@@ -61,15 +61,26 @@ class ReduceHelper {
FilterInvalidSearchResult(SearchResult* search_result);
void
ReduceResultData(int slice_index);
FillPrimaryKey();
void
RefreshSearchResult();
void
FillEntryData();
int64_t
ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& result_offset);
void
ReduceResultData();
std::vector<char>
GetSearchResultDataSlice(int slice_index_, int64_t result_count);
GetSearchResultDataSlice(int slice_index_);
private:
std::vector<int64_t> slice_topKs_;
std::vector<int64_t> slice_nqs_;
int64_t unify_topK_;
int64_t total_nq_;
int64_t num_segments_;
int64_t num_slices_;
......@@ -77,10 +88,10 @@ class ReduceHelper {
milvus::query::Plan* plan_;
std::vector<SearchResult*>& search_results_;
//
std::vector<int32_t> nq_slice_offsets_;
std::vector<std::vector<int64_t>> final_search_records_;
std::vector<std::vector<int64_t>> final_real_topKs_;
std::vector<int64_t> slice_nqs_prefix_sum_;
// dim0: num_segments_; dim1: total_nq_; dim2: offset
std::vector<std::vector<std::vector<int64_t>>> final_search_records_;
// output
std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;
......
......@@ -1347,25 +1347,16 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
auto suc = search_result_data.ParseFromArray(search_result_data_blobs->blobs[i].data(),
search_result_data_blobs->blobs[i].size());
assert(suc);
assert(suc);
assert(search_result_data.num_queries() == slice_nqs[i]);
assert(search_result_data.top_k() == slice_topKs[i]);
assert(search_result_data.scores().size() == slice_topKs[i] * slice_nqs[i]);
assert(search_result_data.ids().int_id().data_size() == slice_topKs[i] * slice_nqs[i]);
assert(search_result_data.scores().size() == search_result_data.topks().at(0) * slice_nqs[i]);
assert(search_result_data.ids().int_id().data_size() == search_result_data.topks().at(0) * slice_nqs[i]);
// check topKs
// check real topks
assert(search_result_data.topks().size() == slice_nqs[i]);
for (int j = 0; j < search_result_data.topks().size(); j++) {
assert(search_result_data.topks().at(j) == slice_topKs[i]);
for (auto real_topk : search_result_data.topks()) {
assert(real_topk <= slice_topKs[i]);
}
// assert(search_result_data.scores().size() == slice_topKs[i] * slice_nqs[i]);
// assert(search_result_data.ids().int_id().data_size() == slice_topKs[i] * slice_nqs[i]);
// assert(search_result_data.top_k() == topK);
// assert(search_result_data.num_queries() == req_sizes[i]);
// assert(search_result_data.scores().size() == topK * req_sizes[i]);
// assert(search_result_data.ids().int_id().data_size() == topK * req_sizes[i]);
}
DeleteSearchResultDataBlobs(cSearchResultData);
......@@ -1378,6 +1369,8 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
}
TEST(CApiTest, ReduceSearchWithExpr) {
testReduceSearchWithExpr(2, 1, 1);
testReduceSearchWithExpr(2, 10, 10);
testReduceSearchWithExpr(100, 1, 1);
testReduceSearchWithExpr(100, 10, 10);
testReduceSearchWithExpr(10000, 1, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册