From f8cdf0786ba1abf51bda40dccc1fb8ed0c68e02d Mon Sep 17 00:00:00 2001 From: starlord Date: Thu, 4 Jul 2019 16:58:40 +0800 Subject: [PATCH] add unittest for merge result functions Former-commit-id: 0ad3ac4b08e06a1c64249aea05f0c62efa3fe57a --- cpp/CHANGELOG.md | 1 + cpp/src/db/scheduler/context/SearchContext.h | 4 +- cpp/src/db/scheduler/task/SearchTask.cpp | 215 ++++++++++--------- cpp/src/db/scheduler/task/SearchTask.h | 14 ++ cpp/unittest/db/search_test.cpp | 162 ++++++++++++++ 5 files changed, 294 insertions(+), 102 deletions(-) create mode 100644 cpp/unittest/db/search_test.cpp diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index 630b86d7..42b11ab9 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -8,6 +8,7 @@ Please mark all change in change log and use the ticket from JIRA. ## Bug ## Improvement +- MS-156 - Add unittest for merge result functions ## New Feature diff --git a/cpp/src/db/scheduler/context/SearchContext.h b/cpp/src/db/scheduler/context/SearchContext.h index 1997b807..e81622eb 100644 --- a/cpp/src/db/scheduler/context/SearchContext.h +++ b/cpp/src/db/scheduler/context/SearchContext.h @@ -32,8 +32,8 @@ public: using Id2IndexMap = std::unordered_map; const Id2IndexMap& GetIndexMap() const { return map_index_files_; } - using Id2ScoreMap = std::vector>; - using ResultSet = std::vector; + using Id2DistanceMap = std::vector>; + using ResultSet = std::vector; const ResultSet& GetResult() const { return result_; } ResultSet& GetResult() { return result_; } diff --git a/cpp/src/db/scheduler/task/SearchTask.cpp b/cpp/src/db/scheduler/task/SearchTask.cpp index 2bfac90e..708bcc87 100644 --- a/cpp/src/db/scheduler/task/SearchTask.cpp +++ b/cpp/src/db/scheduler/task/SearchTask.cpp @@ -13,104 +13,6 @@ namespace milvus { namespace engine { namespace { -void ClusterResult(const std::vector &output_ids, - const std::vector &output_distence, - uint64_t nq, - uint64_t topk, - SearchContext::ResultSet &result_set) { - result_set.clear(); - 
result_set.reserve(nq); - for (auto i = 0; i < nq; i++) { - SearchContext::Id2ScoreMap id_score; - id_score.reserve(topk); - for (auto k = 0; k < topk; k++) { - uint64_t index = i * topk + k; - if(output_ids[index] < 0) { - continue; - } - id_score.push_back(std::make_pair(output_ids[index], output_distence[index])); - } - result_set.emplace_back(id_score); - } -} - -void MergeResult(SearchContext::Id2ScoreMap &score_src, - SearchContext::Id2ScoreMap &score_target, - uint64_t topk) { - //Note: the score_src and score_target are already arranged by score in ascending order - if(score_src.empty()) { - return; - } - - if(score_target.empty()) { - score_target.swap(score_src); - return; - } - - size_t src_count = score_src.size(); - size_t target_count = score_target.size(); - SearchContext::Id2ScoreMap score_merged; - score_merged.reserve(topk); - size_t src_index = 0, target_index = 0; - while(true) { - //all score_src items are merged, if score_merged.size() still less than topk - //move items from score_target to score_merged until score_merged.size() equal topk - if(src_index >= src_count) { - for(size_t i = target_index; i < target_count && score_merged.size() < topk; ++i) { - score_merged.push_back(score_target[i]); - } - break; - } - - //all score_target items are merged, if score_merged.size() still less than topk - //move items from score_src to score_merged until score_merged.size() equal topk - if(target_index >= target_count) { - for(size_t i = src_index; i < src_count && score_merged.size() < topk; ++i) { - score_merged.push_back(score_src[i]); - } - break; - } - - //compare score, put smallest score to score_merged one by one - auto& src_pair = score_src[src_index]; - auto& target_pair = score_target[target_index]; - if(src_pair.second > target_pair.second) { - score_merged.push_back(target_pair); - target_index++; - } else { - score_merged.push_back(src_pair); - src_index++; - } - - //score_merged.size() already equal topk - if(score_merged.size() >= 
topk) { - break; - } - } - - score_target.swap(score_merged); -} - -void TopkResult(SearchContext::ResultSet &result_src, - uint64_t topk, - SearchContext::ResultSet &result_target) { - if (result_target.empty()) { - result_target.swap(result_src); - return; - } - - if (result_src.size() != result_target.size()) { - SERVER_LOG_ERROR << "Invalid result set"; - return; - } - - for (size_t i = 0; i < result_src.size(); i++) { - SearchContext::Id2ScoreMap &score_src = result_src[i]; - SearchContext::Id2ScoreMap &score_target = result_target[i]; - MergeResult(score_src, score_target, topk); - } -} - void CollectDurationMetrics(int index_type, double total_time) { switch(index_type) { case meta::TableFileSchema::RAW: { @@ -165,11 +67,11 @@ std::shared_ptr SearchTask::Execute() { //step 3: cluster result SearchContext::ResultSet result_set; auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk(); - ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set); + SearchTask::ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set); rc.Record("cluster result"); //step 4: pick up topk result - TopkResult(result_set, inner_k, context->GetResult()); + SearchTask::TopkResult(result_set, inner_k, context->GetResult()); rc.Record("reduce topk"); } catch (std::exception& ex) { @@ -191,6 +93,119 @@ std::shared_ptr SearchTask::Execute() { return nullptr; } +Status SearchTask::ClusterResult(const std::vector &output_ids, + const std::vector &output_distence, + uint64_t nq, + uint64_t topk, + SearchContext::ResultSet &result_set) { + if(output_ids.size() != nq*topk || output_distence.size() != nq*topk) { + std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) + + " distance array size: " + std::to_string(output_distence.size()); + SERVER_LOG_ERROR << msg; + return Status::Error(msg); + } + + result_set.clear(); + result_set.reserve(nq); + for (auto i = 0; i < nq; i++) { + 
SearchContext::Id2DistanceMap id_distance; + id_distance.reserve(topk); + for (auto k = 0; k < topk; k++) { + uint64_t index = i * topk + k; + if(output_ids[index] < 0) { + continue; + } + id_distance.push_back(std::make_pair(output_ids[index], output_distence[index])); + } + result_set.emplace_back(id_distance); + } + + return Status::OK(); +} + +Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src, + SearchContext::Id2DistanceMap &distance_target, + uint64_t topk) { + //Note: the score_src and score_target are already arranged by score in ascending order + if(distance_src.empty()) { + SERVER_LOG_WARNING << "Empty distance source array"; + return Status::OK(); + } + + if(distance_target.empty()) { + distance_target.swap(distance_src); + return Status::OK(); + } + + size_t src_count = distance_src.size(); + size_t target_count = distance_target.size(); + SearchContext::Id2DistanceMap distance_merged; + distance_merged.reserve(topk); + size_t src_index = 0, target_index = 0; + while(true) { + //all score_src items are merged, if score_merged.size() still less than topk + //move items from score_target to score_merged until score_merged.size() equal topk + if(src_index >= src_count) { + for(size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) { + distance_merged.push_back(distance_target[i]); + } + break; + } + + //all score_target items are merged, if score_merged.size() still less than topk + //move items from score_src to score_merged until score_merged.size() equal topk + if(target_index >= target_count) { + for(size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) { + distance_merged.push_back(distance_src[i]); + } + break; + } + + //compare score, put smallest score to score_merged one by one + auto& src_pair = distance_src[src_index]; + auto& target_pair = distance_target[target_index]; + if(src_pair.second > target_pair.second) { + distance_merged.push_back(target_pair); + target_index++; 
+ } else { + distance_merged.push_back(src_pair); + src_index++; + } + + //score_merged.size() already equal topk + if(distance_merged.size() >= topk) { + break; + } + } + + distance_target.swap(distance_merged); + + return Status::OK(); +} + +Status SearchTask::TopkResult(SearchContext::ResultSet &result_src, + uint64_t topk, + SearchContext::ResultSet &result_target) { + if (result_target.empty()) { + result_target.swap(result_src); + return Status::OK(); + } + + if (result_src.size() != result_target.size()) { + std::string msg = "Invalid result set size"; + SERVER_LOG_ERROR << msg; + return Status::Error(msg); + } + + for (size_t i = 0; i < result_src.size(); i++) { + SearchContext::Id2DistanceMap &score_src = result_src[i]; + SearchContext::Id2DistanceMap &score_target = result_target[i]; + SearchTask::MergeResult(score_src, score_target, topk); + } + + return Status::OK(); +} + } } } diff --git a/cpp/src/db/scheduler/task/SearchTask.h b/cpp/src/db/scheduler/task/SearchTask.h index 0b3a236c..e4f0d872 100644 --- a/cpp/src/db/scheduler/task/SearchTask.h +++ b/cpp/src/db/scheduler/task/SearchTask.h @@ -19,6 +19,20 @@ public: virtual std::shared_ptr Execute() override; + static Status ClusterResult(const std::vector &output_ids, + const std::vector &output_distence, + uint64_t nq, + uint64_t topk, + SearchContext::ResultSet &result_set); + + static Status MergeResult(SearchContext::Id2DistanceMap &distance_src, + SearchContext::Id2DistanceMap &distance_target, + uint64_t topk); + + static Status TopkResult(SearchContext::ResultSet &result_src, + uint64_t topk, + SearchContext::ResultSet &result_target); + public: size_t index_id_ = 0; int index_type_ = 0; //for metrics diff --git a/cpp/unittest/db/search_test.cpp b/cpp/unittest/db/search_test.cpp new file mode 100644 index 00000000..db10bcba --- /dev/null +++ b/cpp/unittest/db/search_test.cpp @@ -0,0 +1,162 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 
上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved +// Unauthorized copying of this file, via any medium is strictly prohibited. +// Proprietary and confidential. +//////////////////////////////////////////////////////////////////////////////// +#include + +#include "db/scheduler/task/SearchTask.h" + +#include + +using namespace zilliz::milvus; + +namespace { + +static constexpr uint64_t NQ = 15; +static constexpr uint64_t TOP_K = 64; + +void BuildResult(uint64_t nq, + uint64_t top_k, + std::vector &output_ids, + std::vector &output_distence) { + output_ids.clear(); + output_ids.resize(nq*top_k); + output_distence.clear(); + output_distence.resize(nq*top_k); + + for(uint64_t i = 0; i < nq; i++) { + for(uint64_t j = 0; j < top_k; j++) { + output_ids[i * top_k + j] = (long)(drand48()*100000); + output_distence[i * top_k + j] = j + drand48(); + } + } +} + +void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1, + const engine::SearchContext::Id2DistanceMap& src_2, + const engine::SearchContext::Id2DistanceMap& target) { + for(uint64_t i = 0; i < target.size() - 1; i++) { + ASSERT_LE(target[i].second, target[i + 1].second); + } + + using ID2DistMap = std::map; + ID2DistMap src_map_1, src_map_2; + for(const auto& pair : src_1) { + src_map_1.insert(pair); + } + for(const auto& pair : src_2) { + src_map_2.insert(pair); + } + + for(const auto& pair : target) { + ASSERT_TRUE(src_map_1.find(pair.first) != src_map_1.end() || src_map_2.find(pair.first) != src_map_2.end()); + + float dist = src_map_1.find(pair.first) != src_map_1.end() ? 
src_map_1[pair.first] : src_map_2[pair.first]; + ASSERT_LT(fabs(pair.second - dist), std::numeric_limits::epsilon()); + } +} + +} + +TEST(DBSearchTest, TOPK_TEST) { + std::vector target_ids; + std::vector target_distence; + engine::SearchContext::ResultSet src_result; + auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result); + ASSERT_FALSE(status.ok()); + ASSERT_TRUE(src_result.empty()); + + BuildResult(NQ, TOP_K, target_ids, target_distence); + status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(src_result.size(), NQ); + + engine::SearchContext::ResultSet target_result; + status = engine::SearchTask::TopkResult(target_result, TOP_K, target_result); + ASSERT_TRUE(status.ok()); + + status = engine::SearchTask::TopkResult(target_result, TOP_K, src_result); + ASSERT_FALSE(status.ok()); + + status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result); + ASSERT_TRUE(status.ok()); + ASSERT_TRUE(src_result.empty()); + ASSERT_EQ(target_result.size(), NQ); + + std::vector src_ids; + std::vector src_distence; + uint64_t wrong_topk = TOP_K - 10; + BuildResult(NQ, wrong_topk, src_ids, src_distence); + + status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result); + ASSERT_TRUE(status.ok()); + + status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result); + ASSERT_TRUE(status.ok()); + for(uint64_t i = 0; i < NQ; i++) { + ASSERT_EQ(target_result[i].size(), TOP_K); + } + + wrong_topk = TOP_K + 10; + BuildResult(NQ, wrong_topk, src_ids, src_distence); + + status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result); + ASSERT_TRUE(status.ok()); + for(uint64_t i = 0; i < NQ; i++) { + ASSERT_EQ(target_result[i].size(), TOP_K); + } +} + +TEST(DBSearchTest, MERGE_TEST) { + std::vector target_ids; + std::vector target_distence; + std::vector src_ids; + std::vector src_distence; + 
engine::SearchContext::ResultSet src_result, target_result; + + uint64_t src_count = 5, target_count = 8; + BuildResult(1, src_count, src_ids, src_distence); + BuildResult(1, target_count, target_ids, target_distence); + auto status = engine::SearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result); + ASSERT_TRUE(status.ok()); + status = engine::SearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result); + ASSERT_TRUE(status.ok()); + + { + engine::SearchContext::Id2DistanceMap src = src_result[0]; + engine::SearchContext::Id2DistanceMap target = target_result[0]; + status = engine::SearchTask::MergeResult(src, target, 10); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(target.size(), 10); + CheckResult(src_result[0], target_result[0], target); + } + + { + engine::SearchContext::Id2DistanceMap src = src_result[0]; + engine::SearchContext::Id2DistanceMap target; + status = engine::SearchTask::MergeResult(src, target, 10); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(target.size(), src_count); + ASSERT_TRUE(src.empty()); + CheckResult(src_result[0], target_result[0], target); + } + + { + engine::SearchContext::Id2DistanceMap src = src_result[0]; + engine::SearchContext::Id2DistanceMap target = target_result[0]; + status = engine::SearchTask::MergeResult(src, target, 30); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(target.size(), src_count + target_count); + CheckResult(src_result[0], target_result[0], target); + } + + { + engine::SearchContext::Id2DistanceMap target = src_result[0]; + engine::SearchContext::Id2DistanceMap src = target_result[0]; + status = engine::SearchTask::MergeResult(src, target, 30); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(target.size(), src_count + target_count); + CheckResult(src_result[0], target_result[0], target); + } +} \ No newline at end of file -- GitLab