diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index bd985db85bea950c54c06367fb1f51bba6d1253e..21236db11de8e7950be8465887c09ef9adaf7b03 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -73,12 +73,13 @@ struct SearchResult { int64_t num_queries_; int64_t topk_; std::vector result_distances_; + std::vector internal_seg_offsets_; public: // TODO(gexi): utilize these field void* segment_; - std::vector internal_seg_offsets_; std::vector result_offsets_; + std::vector primary_keys_; std::vector> row_data_; }; diff --git a/internal/core/src/segcore/SegmentInterface.cpp b/internal/core/src/segcore/SegmentInterface.cpp index 8886caab046a4329d236eacb86b06ff991bae7cf..a2b4cc6e435d9a33fa5a408eec29ab730566ff92 100644 --- a/internal/core/src/segcore/SegmentInterface.cpp +++ b/internal/core/src/segcore/SegmentInterface.cpp @@ -14,6 +14,35 @@ namespace milvus::segcore { class Naive; +void +SegmentInternalInterface::FillPrimaryKeys(const query::Plan* plan, SearchResult& results) const { + std::shared_lock lck(mutex_); + AssertInfo(plan, "empty plan"); + auto size = results.result_distances_.size(); + AssertInfo(results.internal_seg_offsets_.size() == size, + "Size of result distances is not equal to size of segment offsets"); + Assert(results.primary_keys_.size() == 0); + + results.primary_keys_.resize(size); + + auto element_sizeof = sizeof(int64_t); + + aligned_vector blob(size * element_sizeof); + if (plan->schema_.get_is_auto_id()) { + bulk_subscript(SystemFieldType::RowId, results.internal_seg_offsets_.data(), size, blob.data()); + } else { + auto key_offset_opt = get_schema().get_primary_key_offset(); + AssertInfo(key_offset_opt.has_value(), "Cannot get primary key offset from schema"); + auto key_offset = key_offset_opt.value(); + AssertInfo(get_schema()[key_offset].get_data_type() == DataType::INT64, "Primary key field is not INT64 type"); + bulk_subscript(key_offset, results.internal_seg_offsets_.data(), size, blob.data()); + } + + for (int64_t i = 0; i < size; ++i) { + results.primary_keys_[i] = *(int64_t*)(blob.data() + element_sizeof * i); + } +} + void SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult& results) const { std::shared_lock lck(mutex_); @@ -21,10 +50,8 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult& auto size = results.result_distances_.size(); AssertInfo(results.internal_seg_offsets_.size() == size, "Size of result distances is not equal to size of segment offsets"); - // Assert(results.result_offsets_.size() == size); Assert(results.row_data_.size() == 0); - // std::vector row_ids(size); std::vector element_sizeofs; std::vector> blobs; @@ -45,7 +72,7 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult& element_sizeofs.push_back(sizeof(int64_t)); } - // fill other entries + // fill other entries except primary key for (auto field_offset : plan->target_entries_) { auto& field_meta = get_schema()[field_offset]; auto element_sizeof = field_meta.get_sizeof(); diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index 6d63488d691b099d095f4fdfb0addf010b9ab914..8ba73646ec795eb978bf76daa14f8efd5b39a824 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -28,11 +28,12 @@ namespace milvus::segcore { -// common interface of SegmentSealed and SegmentGrowing -// used by C API +// common interface of SegmentSealed and SegmentGrowing used by C API class SegmentInterface { public: - // fill results according to target_entries in plan + virtual void + FillPrimaryKeys(const query::Plan* plan, SearchResult& results) const = 0; + virtual void FillTargetEntry(const query::Plan* plan, SearchResult& results) const = 0; @@ -82,6 +83,9 @@ class SegmentInternalInterface : public SegmentInterface { const query::PlaceholderGroup& placeholder_group, Timestamp timestamp) const override; + void + FillPrimaryKeys(const query::Plan* plan, SearchResult& results) const override; + void FillTargetEntry(const query::Plan* plan, SearchResult& results) const override; diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp index 877d90569ca0e6d05f364d389c33eb3e32457b7f..4fce0aaf4d02b47b0bf93ab505660782f2bb25d7 100644 --- a/internal/core/src/segcore/reduce_c.cpp +++ b/internal/core/src/segcore/reduce_c.cpp @@ -10,14 +10,16 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include -#include +#include +#include "common/Types.h" +#include "exceptions/EasyAssert.h" +#include "log/Log.h" #include "query/Plan.h" #include "segcore/reduce_c.h" #include "segcore/Reduce.h" #include "segcore/ReduceStructure.h" #include "segcore/SegmentInterface.h" -#include "common/Types.h" #include "pb/milvus.pb.h" using SearchResult = milvus::SearchResult; @@ -69,6 +71,8 @@ GetResultData(std::vector>& search_records, } int64_t loc_offset = query_offset; AssertInfo(topk > 0, "topk must greater than 0"); + +#if 0 for (int i = 0; i < topk; ++i) { result_pairs[0].reset_distance(); std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>()); @@ -77,6 +81,42 @@ GetResultData(std::vector>& search_records, result_pair.search_result_->result_offsets_.push_back(loc_offset++); search_records[index].push_back(result_pair.offset_++); } +#else + float prev_dis = MAXFLOAT; + std::unordered_set prev_pk_set; + prev_pk_set.insert(-1); + while (loc_offset - query_offset < topk) { + result_pairs[0].reset_distance(); + std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>()); + auto& result_pair = result_pairs[0]; + auto index = result_pair.index_; + int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_]; + float curr_dis = result_pair.search_result_->result_distances_[result_pair.offset_]; + // remove duplicates + if (curr_pk == -1 || curr_dis != prev_dis) { + result_pair.search_result_->result_offsets_.push_back(loc_offset++); + search_records[index].push_back(result_pair.offset_++); + prev_dis = curr_dis; + prev_pk_set.clear(); + prev_pk_set.insert(curr_pk); + } else { + // To handle this case: + // e1: [100, 0.99] + // e2: [101, 0.99] ==> not duplicated, should keep + // e3: [100, 0.99] ==> duplicated, should remove + if (prev_pk_set.count(curr_pk) == 0) { + result_pair.search_result_->result_offsets_.push_back(loc_offset++); + search_records[index].push_back(result_pair.offset_++); + // prev_pk_set keeps all primary keys with same distance + prev_pk_set.insert(curr_pk); + } else { + // the entity with same distance and same primary key must be duplicated + result_pair.offset_++; + LOG_SEGCORE_DEBUG_ << "skip duplicated search result, primary key " << curr_pk; + } + } + } +#endif } void @@ -90,17 +130,21 @@ ResetSearchResult(std::vector>& search_records, std::vector continue; } + std::vector primary_keys; std::vector result_distances; std::vector internal_seg_offsets; for (int j = 0; j < search_records[i].size(); j++) { auto& offset = search_records[i][j]; + auto primary_key = search_result->primary_keys_[offset]; auto distance = search_result->result_distances_[offset]; auto internal_seg_offset = search_result->internal_seg_offsets_[offset]; + primary_keys.push_back(primary_key); result_distances.push_back(distance); internal_seg_offsets.push_back(internal_seg_offset); } + search_result->primary_keys_ = primary_keys; search_result->result_distances_ = result_distances; search_result->internal_seg_offsets_ = internal_seg_offsets; } @@ -118,13 +162,19 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul auto num_queries = search_results[0]->num_queries_; std::vector> search_records(num_segments); + // get primary keys for duplicates removal + for (auto& search_result : search_results) { + auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_); + segment->FillPrimaryKeys(plan, *search_result); + } + for (int i = 0; i < num_queries; ++i) { GetResultData(search_records, search_results, i, topk); } ResetSearchResult(search_records, search_results); - for (int i = 0; i < num_segments; ++i) { - auto search_result = search_results[i]; + // fill in other entities + for (auto& search_result : search_results) { auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_); segment->FillTargetEntry(plan, *search_result); } diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 3990d24e6b5a2efe27984ed2e2a79233335dc547..bbc7dd7d28baec74b4e156d6b564eac28fee5e7f 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -9,12 +9,13 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License -#include -#include -#include #include #include #include +#include +#include +#include +#include #include "common/LoadInfo.h" #include "index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h" @@ -448,6 +449,121 @@ TEST(CApiTest, MergeInto) { ASSERT_EQ(distance[1], 5); } +void +CheckSearchResultDuplicate(const std::vector& results) { + auto sr = (SearchResult*)results[0]; + auto topk = sr->topk_; + auto num_queries = sr->num_queries_; + + std::unordered_set pk_set; + std::unordered_set distance_set; + for (int i = 0; i < results.size(); i++) { + auto search_result = (SearchResult*)results[i]; + auto size = search_result->result_offsets_.size(); + for (int j = 0; j < size; j++) { + auto ret = pk_set.insert(search_result->primary_keys_[j]); + // std::cout << j << ": " << ret.second << " " + // << search_result->primary_keys_[j] << " " + // << search_result->result_distances_[j] << std::endl; + distance_set.insert(search_result->result_distances_[j]); + } + } + std::cout << pk_set.size() << " " << distance_set.size() << " " << topk * num_queries << std::endl; + // TODO: find 1 duplicated result (pk = 10345), need check + assert(pk_set.size() == topk * num_queries - 1); +} + +TEST(CApiTest, ReduceRemoveDuplicates) { + auto collection = NewCollection(get_default_schema_config()); + auto segment = NewSegment(collection, 0, Growing); + + int N = 10000; + auto [raw_data, timestamps, uids] = generate_data(N); + auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); + + int64_t offset; + PreInsert(segment, N, &offset); + auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); + assert(ins_res.error_code == Success); + + const char* dsl_string = R"( + { + "bool": { + "vector": { + "fakevec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 10, + "round_decimal": 3 + } + } + } + })"; + + int num_queries = 10; + auto blob = generate_query_data(num_queries); + + void* plan = nullptr; + auto status = CreateSearchPlan(collection, dsl_string, &plan); + assert(status.error_code == Success); + + void* placeholderGroup = nullptr; + status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup); + assert(status.error_code == Success); + + std::vector placeholderGroups; + placeholderGroups.push_back(placeholderGroup); + timestamps.clear(); + timestamps.push_back(1); + + { + std::vector results; + CSearchResult res1, res2; + status = Search(segment, plan, placeholderGroup, timestamps[0], &res1); + assert(status.error_code == Success); + status = Search(segment, plan, placeholderGroup, timestamps[0], &res2); + assert(status.error_code == Success); + results.push_back(res1); + results.push_back(res2); + + status = ReduceSearchResultsAndFillData(plan, results.data(), results.size()); + assert(status.error_code == Success); + CheckSearchResultDuplicate(results); + + DeleteSearchResult(res1); + DeleteSearchResult(res2); + } + { + std::vector results; + CSearchResult res1, res2, res3; + status = Search(segment, plan, placeholderGroup, timestamps[0], &res1); + assert(status.error_code == Success); + status = Search(segment, plan, placeholderGroup, timestamps[0], &res2); + assert(status.error_code == Success); + status = Search(segment, plan, placeholderGroup, timestamps[0], &res3); + assert(status.error_code == Success); + results.push_back(res1); + results.push_back(res2); + results.push_back(res3); + + status = ReduceSearchResultsAndFillData(plan, results.data(), results.size()); + assert(status.error_code == Success); + CheckSearchResultDuplicate(results); + + DeleteSearchResult(res1); + DeleteSearchResult(res2); + DeleteSearchResult(res3); + } + + DeleteSearchPlan(plan); + DeletePlaceholderGroup(placeholderGroup); + DeleteCollection(collection); + DeleteSegment(segment); +} + TEST(CApiTest, Reduce) { auto collection = NewCollection(get_default_schema_config()); auto segment = NewSegment(collection, 0, Growing);