提交 6207e830 编写于 作者: S starlord 提交者: jinhai

add uiittest for merge result functions


Former-commit-id: 071b7cd18d8acbd6bd6bdceb7dae9a7cf1d1a86c
上级 b4f1e9cb
......@@ -8,6 +8,7 @@ Please mark all change in change log and use the ticket from JIRA.
## Bug
## Improvement
- MS-156 - Add unittest for merge result functions
## New Feature
......
......@@ -32,8 +32,8 @@ public:
using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
const Id2IndexMap& GetIndexMap() const { return map_index_files_; }
using Id2ScoreMap = std::vector<std::pair<int64_t, double>>;
using ResultSet = std::vector<Id2ScoreMap>;
using Id2DistanceMap = std::vector<std::pair<int64_t, double>>;
using ResultSet = std::vector<Id2DistanceMap>;
const ResultSet& GetResult() const { return result_; }
ResultSet& GetResult() { return result_; }
......
......@@ -13,104 +13,6 @@ namespace milvus {
namespace engine {
namespace {
void ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
uint64_t nq,
uint64_t topk,
SearchContext::ResultSet &result_set) {
result_set.clear();
result_set.reserve(nq);
for (auto i = 0; i < nq; i++) {
SearchContext::Id2ScoreMap id_score;
id_score.reserve(topk);
for (auto k = 0; k < topk; k++) {
uint64_t index = i * topk + k;
if(output_ids[index] < 0) {
continue;
}
id_score.push_back(std::make_pair(output_ids[index], output_distence[index]));
}
result_set.emplace_back(id_score);
}
}
void MergeResult(SearchContext::Id2ScoreMap &score_src,
SearchContext::Id2ScoreMap &score_target,
uint64_t topk) {
//Note: the score_src and score_target are already arranged by score in ascending order
if(score_src.empty()) {
return;
}
if(score_target.empty()) {
score_target.swap(score_src);
return;
}
size_t src_count = score_src.size();
size_t target_count = score_target.size();
SearchContext::Id2ScoreMap score_merged;
score_merged.reserve(topk);
size_t src_index = 0, target_index = 0;
while(true) {
//all score_src items are merged, if score_merged.size() still less than topk
//move items from score_target to score_merged until score_merged.size() equal topk
if(src_index >= src_count) {
for(size_t i = target_index; i < target_count && score_merged.size() < topk; ++i) {
score_merged.push_back(score_target[i]);
}
break;
}
//all score_target items are merged, if score_merged.size() still less than topk
//move items from score_src to score_merged until score_merged.size() equal topk
if(target_index >= target_count) {
for(size_t i = src_index; i < src_count && score_merged.size() < topk; ++i) {
score_merged.push_back(score_src[i]);
}
break;
}
//compare score, put smallest score to score_merged one by one
auto& src_pair = score_src[src_index];
auto& target_pair = score_target[target_index];
if(src_pair.second > target_pair.second) {
score_merged.push_back(target_pair);
target_index++;
} else {
score_merged.push_back(src_pair);
src_index++;
}
//score_merged.size() already equal topk
if(score_merged.size() >= topk) {
break;
}
}
score_target.swap(score_merged);
}
void TopkResult(SearchContext::ResultSet &result_src,
uint64_t topk,
SearchContext::ResultSet &result_target) {
if (result_target.empty()) {
result_target.swap(result_src);
return;
}
if (result_src.size() != result_target.size()) {
SERVER_LOG_ERROR << "Invalid result set";
return;
}
for (size_t i = 0; i < result_src.size(); i++) {
SearchContext::Id2ScoreMap &score_src = result_src[i];
SearchContext::Id2ScoreMap &score_target = result_target[i];
MergeResult(score_src, score_target, topk);
}
}
void CollectDurationMetrics(int index_type, double total_time) {
switch(index_type) {
case meta::TableFileSchema::RAW: {
......@@ -164,11 +66,12 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
//step 3: cluster result
SearchContext::ResultSet result_set;
ClusterResult(output_ids, output_distence, context->nq(), inner_k, result_set);
auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
SearchTask::ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);
rc.Record("cluster result");
//step 4: pick up topk result
TopkResult(result_set, inner_k, context->GetResult());
SearchTask::TopkResult(result_set, inner_k, context->GetResult());
rc.Record("reduce topk");
} catch (std::exception& ex) {
......@@ -190,6 +93,119 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
return nullptr;
}
Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
uint64_t nq,
uint64_t topk,
SearchContext::ResultSet &result_set) {
if(output_ids.size() != nq*topk || output_distence.size() != nq*topk) {
std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
" distance array size: " + std::to_string(output_distence.size());
SERVER_LOG_ERROR << msg;
return Status::Error(msg);
}
result_set.clear();
result_set.reserve(nq);
for (auto i = 0; i < nq; i++) {
SearchContext::Id2DistanceMap id_distance;
id_distance.reserve(topk);
for (auto k = 0; k < topk; k++) {
uint64_t index = i * topk + k;
if(output_ids[index] < 0) {
continue;
}
id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
}
result_set.emplace_back(id_distance);
}
return Status::OK();
}
Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
SearchContext::Id2DistanceMap &distance_target,
uint64_t topk) {
//Note: the score_src and score_target are already arranged by score in ascending order
if(distance_src.empty()) {
SERVER_LOG_WARNING << "Empty distance source array";
return Status::OK();
}
if(distance_target.empty()) {
distance_target.swap(distance_src);
return Status::OK();
}
size_t src_count = distance_src.size();
size_t target_count = distance_target.size();
SearchContext::Id2DistanceMap distance_merged;
distance_merged.reserve(topk);
size_t src_index = 0, target_index = 0;
while(true) {
//all score_src items are merged, if score_merged.size() still less than topk
//move items from score_target to score_merged until score_merged.size() equal topk
if(src_index >= src_count) {
for(size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) {
distance_merged.push_back(distance_target[i]);
}
break;
}
//all score_target items are merged, if score_merged.size() still less than topk
//move items from score_src to score_merged until score_merged.size() equal topk
if(target_index >= target_count) {
for(size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) {
distance_merged.push_back(distance_src[i]);
}
break;
}
//compare score, put smallest score to score_merged one by one
auto& src_pair = distance_src[src_index];
auto& target_pair = distance_target[target_index];
if(src_pair.second > target_pair.second) {
distance_merged.push_back(target_pair);
target_index++;
} else {
distance_merged.push_back(src_pair);
src_index++;
}
//score_merged.size() already equal topk
if(distance_merged.size() >= topk) {
break;
}
}
distance_target.swap(distance_merged);
return Status::OK();
}
Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
uint64_t topk,
SearchContext::ResultSet &result_target) {
if (result_target.empty()) {
result_target.swap(result_src);
return Status::OK();
}
if (result_src.size() != result_target.size()) {
std::string msg = "Invalid result set size";
SERVER_LOG_ERROR << msg;
return Status::Error(msg);
}
for (size_t i = 0; i < result_src.size(); i++) {
SearchContext::Id2DistanceMap &score_src = result_src[i];
SearchContext::Id2DistanceMap &score_target = result_target[i];
SearchTask::MergeResult(score_src, score_target, topk);
}
return Status::OK();
}
}
}
}
......@@ -19,6 +19,20 @@ public:
virtual std::shared_ptr<IScheduleTask> Execute() override;
static Status ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
uint64_t nq,
uint64_t topk,
SearchContext::ResultSet &result_set);
static Status MergeResult(SearchContext::Id2DistanceMap &distance_src,
SearchContext::Id2DistanceMap &distance_target,
uint64_t topk);
static Status TopkResult(SearchContext::ResultSet &result_src,
uint64_t topk,
SearchContext::ResultSet &result_target);
public:
size_t index_id_ = 0;
int index_type_ = 0; //for metrics
......
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include <gtest/gtest.h>
#include "db/scheduler/task/SearchTask.h"
#include <vector>
using namespace zilliz::milvus;
namespace {
static constexpr uint64_t NQ = 15;
static constexpr uint64_t TOP_K = 64;
void BuildResult(uint64_t nq,
uint64_t top_k,
std::vector<long> &output_ids,
std::vector<float> &output_distence) {
output_ids.clear();
output_ids.resize(nq*top_k);
output_distence.clear();
output_distence.resize(nq*top_k);
for(uint64_t i = 0; i < nq; i++) {
for(uint64_t j = 0; j < top_k; j++) {
output_ids[i * top_k + j] = (long)(drand48()*100000);
output_distence[i * top_k + j] = j + drand48();
}
}
}
void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1,
const engine::SearchContext::Id2DistanceMap& src_2,
const engine::SearchContext::Id2DistanceMap& target) {
for(uint64_t i = 0; i < target.size() - 1; i++) {
ASSERT_LE(target[i].second, target[i + 1].second);
}
using ID2DistMap = std::map<long, float>;
ID2DistMap src_map_1, src_map_2;
for(const auto& pair : src_1) {
src_map_1.insert(pair);
}
for(const auto& pair : src_2) {
src_map_2.insert(pair);
}
for(const auto& pair : target) {
ASSERT_TRUE(src_map_1.find(pair.first) != src_map_1.end() || src_map_2.find(pair.first) != src_map_2.end());
float dist = src_map_1.find(pair.first) != src_map_1.end() ? src_map_1[pair.first] : src_map_2[pair.first];
ASSERT_LT(fabs(pair.second - dist), std::numeric_limits<float>::epsilon());
}
}
}
TEST(DBSearchTest, TOPK_TEST) {
std::vector<long> target_ids;
std::vector<float> target_distence;
engine::SearchContext::ResultSet src_result;
auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
ASSERT_FALSE(status.ok());
ASSERT_TRUE(src_result.empty());
BuildResult(NQ, TOP_K, target_ids, target_distence);
status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
ASSERT_TRUE(status.ok());
ASSERT_EQ(src_result.size(), NQ);
engine::SearchContext::ResultSet target_result;
status = engine::SearchTask::TopkResult(target_result, TOP_K, target_result);
ASSERT_TRUE(status.ok());
status = engine::SearchTask::TopkResult(target_result, TOP_K, src_result);
ASSERT_FALSE(status.ok());
status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
ASSERT_TRUE(status.ok());
ASSERT_TRUE(src_result.empty());
ASSERT_EQ(target_result.size(), NQ);
std::vector<long> src_ids;
std::vector<float> src_distence;
uint64_t wrong_topk = TOP_K - 10;
BuildResult(NQ, wrong_topk, src_ids, src_distence);
status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
ASSERT_TRUE(status.ok());
status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
ASSERT_TRUE(status.ok());
for(uint64_t i = 0; i < NQ; i++) {
ASSERT_EQ(target_result[i].size(), TOP_K);
}
wrong_topk = TOP_K + 10;
BuildResult(NQ, wrong_topk, src_ids, src_distence);
status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
ASSERT_TRUE(status.ok());
for(uint64_t i = 0; i < NQ; i++) {
ASSERT_EQ(target_result[i].size(), TOP_K);
}
}
TEST(DBSearchTest, MERGE_TEST) {
std::vector<long> target_ids;
std::vector<float> target_distence;
std::vector<long> src_ids;
std::vector<float> src_distence;
engine::SearchContext::ResultSet src_result, target_result;
uint64_t src_count = 5, target_count = 8;
BuildResult(1, src_count, src_ids, src_distence);
BuildResult(1, target_count, target_ids, target_distence);
auto status = engine::SearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
ASSERT_TRUE(status.ok());
status = engine::SearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
ASSERT_TRUE(status.ok());
{
engine::SearchContext::Id2DistanceMap src = src_result[0];
engine::SearchContext::Id2DistanceMap target = target_result[0];
status = engine::SearchTask::MergeResult(src, target, 10);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), 10);
CheckResult(src_result[0], target_result[0], target);
}
{
engine::SearchContext::Id2DistanceMap src = src_result[0];
engine::SearchContext::Id2DistanceMap target;
status = engine::SearchTask::MergeResult(src, target, 10);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count);
ASSERT_TRUE(src.empty());
CheckResult(src_result[0], target_result[0], target);
}
{
engine::SearchContext::Id2DistanceMap src = src_result[0];
engine::SearchContext::Id2DistanceMap target = target_result[0];
status = engine::SearchTask::MergeResult(src, target, 30);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count + target_count);
CheckResult(src_result[0], target_result[0], target);
}
{
engine::SearchContext::Id2DistanceMap target = src_result[0];
engine::SearchContext::Id2DistanceMap src = target_result[0];
status = engine::SearchTask::MergeResult(src, target, 30);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count + target_count);
CheckResult(src_result[0], target_result[0], target);
}
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册