Commit a1eda477, authored by jinhai

Merge branch 'branch-0.3.0' into 'branch-0.3.0'

MS-52 - Return search score

See merge request megasearch/vecwise_engine!65

Former-commit-id: 29bd72ba9d7b3490046f3c8e6faf69a43bd10e1c
@@ -12,7 +12,8 @@ Please mark all change in change log and use the ticket from JIRA.
 - MS-57 - Implement index load/search pipeline
 - MS-56 - Add version information when server is started
-- Ms-64 - Different table can have different index type
+- MS-64 - Different table can have different index type
+- MS-52 - Return search score

 ## Task
......
@@ -240,7 +240,7 @@ Status DBImpl::QuerySync(const std::string& table_id, size_t k, size_t nq,
         int inner_k = dis.size() < k ? dis.size() : k;
         for (int i = 0; i < inner_k; ++i) {
-            res.emplace_back(nns[output_ids[i]]); // mapping
+            res.emplace_back(std::make_pair(nns[output_ids[i]], output_distence[i])); // mapping
         }
         results.push_back(res); // append to result list
         res.clear();
@@ -267,6 +267,8 @@ Status DBImpl::QuerySync(const std::string& table_id, size_t k, size_t nq,
 Status DBImpl::QueryAsync(const std::string& table_id, size_t k, size_t nq,
                           const float* vectors, const meta::DatesT& dates, QueryResults& results) {
+    //step 1: get files to search
     meta::DatePartionedTableFilesSchema files;
     auto status = pMeta_->FilesToSearch(table_id, dates, files);
     if (!status.ok()) { return status; }
@@ -282,18 +284,15 @@ Status DBImpl::QueryAsync(const std::string& table_id, size_t k, size_t nq,
         }
     }

+    //step 2: put search task to scheduler
     SearchScheduler& scheduler = SearchScheduler::GetInstance();
     scheduler.ScheduleSearchTask(context);

     context->WaitResult();

+    //step 3: construct results
     auto& context_result = context->GetResult();
-    for(auto& topk_result : context_result) {
-        QueryResult ids;
-        for(auto& pair : topk_result) {
-            ids.push_back(pair.second);
-        }
-        results.emplace_back(ids);
-    }
+    results.swap(context_result);

     return Status::OK();
 }
......
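Note on the swap above: results is a QueryResults (std::vector of QueryResult, where QueryResult is now std::vector<std::pair<IDNumber, double>> and IDNumber is a long), while context_result is a SearchContext::ResultSet (std::vector of Id2ScoreMap, i.e. std::vector<std::pair<int64_t, double>>). The call compiles only where long and int64_t are the same type, as on typical LP64 Linux builds. A minimal sketch of that assumption, not part of the patch:

    // Sketch only: checks the type identity that results.swap(context_result) relies on.
    #include <cstdint>
    #include <type_traits>
    #include <utility>
    #include <vector>

    typedef long IDNumber;                                        // as in Types.h
    typedef std::vector<std::pair<IDNumber, double>> QueryResult; // as in Types.h
    using Id2ScoreMap = std::vector<std::pair<int64_t, double>>;  // as in SearchContext.h

    static_assert(std::is_same<QueryResult, Id2ScoreMap>::value,
                  "swap requires IDNumber (long) and int64_t to be the same type");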
@@ -32,6 +32,8 @@ public:
     virtual size_t Size() const = 0;

+    virtual size_t Dimension() const = 0;
+
     virtual size_t PhysicalSize() const = 0;

     virtual Status Serialize() = 0;
......
@@ -54,6 +54,10 @@ size_t FaissExecutionEngine::Size() const {
     return (size_t)(Count() * pIndex_->d)*sizeof(float);
 }

+size_t FaissExecutionEngine::Dimension() const {
+    return pIndex_->d;
+}
+
 size_t FaissExecutionEngine::PhysicalSize() const {
     return (size_t)(Count() * pIndex_->d)*sizeof(float);
 }
......
@@ -38,6 +38,8 @@ public:
     size_t Size() const override;

+    size_t Dimension() const override;
+
     size_t PhysicalSize() const override;

     Status Serialize() override;
......
@@ -15,7 +15,7 @@ typedef long IDNumber;
 typedef IDNumber* IDNumberPtr;
 typedef std::vector<IDNumber> IDNumbers;

-typedef std::vector<IDNumber> QueryResult;
+typedef std::vector<std::pair<IDNumber, double>> QueryResult;
 typedef std::vector<QueryResult> QueryResults;
......
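With this typedef change every hit in a QueryResult carries both the neighbour id and its score, so callers read pairs instead of bare ids. A minimal, self-contained consumer sketch (names and values are illustrative, not from the patch):

    #include <cstdio>
    #include <utility>
    #include <vector>

    typedef long IDNumber;
    typedef std::vector<std::pair<IDNumber, double>> QueryResult; // (id, score)
    typedef std::vector<QueryResult> QueryResults;                // one entry per query vector

    void PrintResults(const QueryResults& results) {
        for (size_t q = 0; q < results.size(); ++q) {
            for (const auto& id_score : results[q]) {
                std::printf("query %zu: id=%ld score=%.2f\n", q, id_score.first, id_score.second);
            }
        }
    }

    int main() {
        QueryResults results = {{{17, 98.5}, {42, 91.2}}}; // one query, two hits (made-up values)
        PrintResults(results);
        return 0;
    }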
@@ -31,8 +31,8 @@ public:
     using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
     const Id2IndexMap& GetIndexMap() const { return map_index_files_; }

-    using Score2IdMap = std::map<float, int64_t>;
-    using ResultSet = std::vector<Score2IdMap>;
+    using Id2ScoreMap = std::vector<std::pair<int64_t, double>>;
+    using ResultSet = std::vector<Id2ScoreMap>;
     const ResultSet& GetResult() const { return result_; }
     ResultSet& GetResult() { return result_; }
......
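One practical difference between the old and new aliases: Score2IdMap was a std::map keyed by the float distance, so two hits with exactly the same distance collapsed into a single entry, while the new Id2ScoreMap is a plain vector of (id, score) pairs and keeps every hit in the order it was pushed. Illustrative snippet, not part of the patch:

    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <utility>
    #include <vector>

    int main() {
        // Old shape: distance-keyed map; the second insert with key 0.5f is ignored.
        std::map<float, int64_t> score2id;
        score2id.insert(std::make_pair(0.5f, 100));
        score2id.insert(std::make_pair(0.5f, 200));
        std::printf("map keeps %zu hit(s)\n", score2id.size());    // prints 1

        // New shape: vector of (id, score); both hits survive.
        std::vector<std::pair<int64_t, double>> id_score = {{100, 0.5}, {200, 0.5}};
        std::printf("vector keeps %zu hit(s)\n", id_score.size()); // prints 2
        return 0;
    }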
@@ -19,12 +19,29 @@ void ClusterResult(const std::vector<long> &output_ids,
                    SearchContext::ResultSet &result_set) {
     result_set.clear();
     for (auto i = 0; i < nq; i++) {
-        SearchContext::Score2IdMap score2id;
+        SearchContext::Id2ScoreMap id_score;
         for (auto k = 0; k < topk; k++) {
             uint64_t index = i * nq + k;
-            score2id.insert(std::make_pair(output_distence[index], output_ids[index]));
+            id_score.push_back(std::make_pair(output_ids[index], output_distence[index]));
         }
-        result_set.emplace_back(score2id);
+        result_set.emplace_back(id_score);
     }
 }

+void MergeResult(SearchContext::Id2ScoreMap &score_src,
+                 SearchContext::Id2ScoreMap &score_target,
+                 uint64_t topk) {
+    for (auto& pair_src : score_src) {
+        for (auto iter = score_target.begin(); iter != score_target.end(); ++iter) {
+            if(pair_src.second > iter->second) {
+                score_target.insert(iter, pair_src);
+            }
+        }
+    }
+
+    //remove unused items
+    while (score_target.size() > topk) {
+        score_target.pop_back();
+    }
+}
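MergeResult above inserts each source pair ahead of target entries with a smaller score and then trims the list to topk. Since std::vector::insert invalidates iterators into the vector being walked, the same top-k merge is often expressed with std::merge over two already-sorted lists; the sketch below assumes both inputs are sorted by descending score and is an illustration, not the project's code:

    #include <algorithm>
    #include <cstdint>
    #include <iterator>
    #include <utility>
    #include <vector>

    using Id2ScoreMap = std::vector<std::pair<int64_t, double>>; // (id, score)

    // Merge two score-descending lists and keep the best `topk` entries.
    Id2ScoreMap MergeTopk(const Id2ScoreMap& a, const Id2ScoreMap& b, uint64_t topk) {
        Id2ScoreMap merged;
        merged.reserve(a.size() + b.size());
        std::merge(a.begin(), a.end(), b.begin(), b.end(), std::back_inserter(merged),
                   [](const std::pair<int64_t, double>& x, const std::pair<int64_t, double>& y) {
                       return x.second > y.second; // higher score first
                   });
        if (merged.size() > topk) {
            merged.resize(topk);
        }
        return merged;
    }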
@@ -42,18 +59,39 @@ void TopkResult(SearchContext::ResultSet &result_src,
     }

     for (size_t i = 0; i < result_src.size(); i++) {
-        SearchContext::Score2IdMap &score2id_src = result_src[i];
-        SearchContext::Score2IdMap &score2id_target = result_target[i];
-        for (auto iter = score2id_src.begin(); iter != score2id_src.end(); ++iter) {
-            score2id_target.insert(std::make_pair(iter->first, iter->second));
-        }
-
-        //remove unused items
-        while (score2id_target.size() > topk) {
-            score2id_target.erase(score2id_target.rbegin()->first);
-        }
+        SearchContext::Id2ScoreMap &score_src = result_src[i];
+        SearchContext::Id2ScoreMap &score_target = result_target[i];
+        MergeResult(score_src, score_target, topk);
     }
 }

+void CalcScore(uint64_t vector_count,
+               const float *vectors_data,
+               uint64_t dimension,
+               const SearchContext::ResultSet &result_src,
+               SearchContext::ResultSet &result_target) {
+    result_target.clear();
+    if(result_src.empty()){
+        return;
+    }
+
+    int vec_index = 0;
+    for(auto& result : result_src) {
+        const float * vec_data = vectors_data + vec_index*dimension;
+        double vec_len = 0;
+        for(uint64_t i = 0; i < dimension; i++) {
+            vec_len += vec_data[i]*vec_data[i];
+        }
+        vec_index++;
+
+        SearchContext::Id2ScoreMap score_array;
+        for(auto& pair : result) {
+            score_array.push_back(std::make_pair(pair.first, (1 - pair.second/vec_len)*100.0));
+        }
+        result_target.emplace_back(score_array);
+    }
+}
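CalcScore maps each raw distance into a 0 to 100 score per query: vec_len is the squared norm of the query vector and each hit becomes (1 - distance / vec_len) * 100. Assuming the distance returned by the search is a squared L2 distance, as a Faiss flat L2 index reports, an exact match scores 100 and the score falls to 0 as the distance approaches the query's squared norm. A small worked example under that assumption, with made-up numbers:

    #include <cstdio>

    int main() {
        // Hypothetical query vector of dimension 4: squared norm = 1 + 4 + 0 + 4 = 9.
        const float query[4] = {1.0f, 2.0f, 0.0f, 2.0f};
        double vec_len = 0;
        for (int i = 0; i < 4; ++i) {
            vec_len += query[i] * query[i];
        }

        // Squared L2 distances as the index might report them; smaller means closer.
        const double distances[3] = {0.0, 2.25, 9.0};
        for (double d : distances) {
            std::printf("distance %.2f -> score %.1f\n", d, (1 - d / vec_len) * 100.0);
        }
        return 0; // prints scores 100.0, 75.0, 0.0
    }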
@@ -78,10 +116,12 @@ bool SearchTask::DoSearch() {
     std::vector<long> output_ids;
     std::vector<float> output_distence;
     for(auto& context : search_contexts_) {
+        //step 1: allocate memory
         auto inner_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
         output_ids.resize(inner_k*context->nq());
         output_distence.resize(inner_k*context->nq());

+        //step 2: search
         try {
             index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(),
                                   output_ids.data());
@@ -93,11 +133,21 @@ bool SearchTask::DoSearch() {
         rc.Record("do search");

+        //step 3: cluster result
         SearchContext::ResultSet result_set;
         ClusterResult(output_ids, output_distence, context->nq(), inner_k, result_set);
         rc.Record("cluster result");

+        //step 4: pick up topk result
         TopkResult(result_set, inner_k, context->GetResult());
         rc.Record("reduce topk");

+        //step 5: calculate score between 0 ~ 100
+        CalcScore(context->nq(), context->vectors(), index_engine_->Dimension(), context->GetResult(), result_set);
+        context->GetResult().swap(result_set);
+        rc.Record("calculate score");

+        //step 6: notify to send result to client
         context->IndexSearchDone(index_id_);
     }
......
@@ -400,9 +400,10 @@ ServerError SearchVectorTask::OnExecute() {
             const auto& record = record_array_[i];

             thrift::TopKQueryResult thrift_topk_result;
-            for(auto id : result) {
+            for(auto& pair : result) {
                 thrift::QueryResult thrift_result;
-                thrift_result.__set_id(id);
+                thrift_result.__set_id(pair.first);
+                thrift_result.__set_score(pair.second);

                 thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
             }
......
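On the wire each hit now carries a score next to its id; the __set_score call implies a score field on thrift::QueryResult alongside id, with hits collected under query_result_arrays in TopKQueryResult. A hypothetical client-side sketch using hand-written mirror structs (the real types come from the project's Thrift-generated code, which is not shown here):

    #include <cstdio>
    #include <vector>

    // Hand-written stand-ins for the generated structs, assuming plain `id` and
    // `score` members as the __set_id/__set_score calls suggest.
    struct QueryResult { long id; double score; };
    struct TopKQueryResult { std::vector<QueryResult> query_result_arrays; };

    void PrintTopk(const TopKQueryResult& topk) {
        for (const auto& hit : topk.query_result_arrays) {
            std::printf("id=%ld score=%.2f\n", hit.id, hit.score);
        }
    }

    int main() {
        TopKQueryResult topk;
        topk.query_result_arrays.push_back({17, 98.5}); // made-up values
        PrintTopk(topk);
        return 0;
    }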
@@ -164,11 +164,11 @@ TEST_F(DBTest, DB_TEST) {
     ASSERT_STATS(stat);
     for (auto k=0; k<qb; ++k) {
-        ASSERT_EQ(results[k][0], target_ids[k]);
+        ASSERT_EQ(results[k][0].first, target_ids[k]);

         ss.str("");
         ss << "Result [" << k << "]:";
         for (auto result : results[k]) {
-            ss << result << " ";
+            ss << result.first << " ";
         }

         /* LOG(DEBUG) << ss.str(); */
     }
......
@@ -87,11 +87,11 @@ TEST_F(DBTest, Metric_Tes) {
     ASSERT_STATS(stat);
     for (auto k=0; k<qb; ++k) {
-        ASSERT_EQ(results[k][0], target_ids[k]);
+        ASSERT_EQ(results[k][0].first, target_ids[k]);

         ss.str("");
         ss << "Result [" << k << "]:";
         for (auto result : results[k]) {
-            ss << result << " ";
+            ss << result.first << " ";
         }

         /* LOG(DEBUG) << ss.str(); */
     }
......