提交 a1eda477 编写于 作者: J jinhai

Merge branch 'branch-0.3.0' into 'branch-0.3.0'

MS-52 - Return search score

See merge request megasearch/vecwise_engine!65

Former-commit-id: 29bd72ba9d7b3490046f3c8e6faf69a43bd10e1c
......@@ -12,7 +12,8 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-57 - Implement index load/search pipeline
- MS-56 - Add version information when server is started
- Ms-64 - Different table can have different index type
- MS-64 - Different table can have different index type
- MS-52 - Return search score
## Task
......
......@@ -240,7 +240,7 @@ Status DBImpl::QuerySync(const std::string& table_id, size_t k, size_t nq,
int inner_k = dis.size() < k ? dis.size() : k;
for (int i = 0; i < inner_k; ++i) {
res.emplace_back(nns[output_ids[i]]); // mapping
res.emplace_back(std::make_pair(nns[output_ids[i]], output_distence[i])); // mapping
}
results.push_back(res); // append to result list
res.clear();
......@@ -267,6 +267,8 @@ Status DBImpl::QuerySync(const std::string& table_id, size_t k, size_t nq,
Status DBImpl::QueryAsync(const std::string& table_id, size_t k, size_t nq,
const float* vectors, const meta::DatesT& dates, QueryResults& results) {
//step 1: get files to search
meta::DatePartionedTableFilesSchema files;
auto status = pMeta_->FilesToSearch(table_id, dates, files);
if (!status.ok()) { return status; }
......@@ -282,18 +284,15 @@ Status DBImpl::QueryAsync(const std::string& table_id, size_t k, size_t nq,
}
}
//step 2: put search task to scheduler
SearchScheduler& scheduler = SearchScheduler::GetInstance();
scheduler.ScheduleSearchTask(context);
context->WaitResult();
//step 3: construct results
auto& context_result = context->GetResult();
for(auto& topk_result : context_result) {
QueryResult ids;
for(auto& pair : topk_result) {
ids.push_back(pair.second);
}
results.emplace_back(ids);
}
results.swap(context_result);
return Status::OK();
}
......
......@@ -32,6 +32,8 @@ public:
virtual size_t Size() const = 0;
virtual size_t Dimension() const = 0;
virtual size_t PhysicalSize() const = 0;
virtual Status Serialize() = 0;
......
......@@ -54,6 +54,10 @@ size_t FaissExecutionEngine::Size() const {
return (size_t)(Count() * pIndex_->d)*sizeof(float);
}
// Returns the `d` field of the wrapped index as the engine's dimension.
size_t FaissExecutionEngine::Dimension() const {
    const auto index_dimension = pIndex_->d;
    return index_dimension;
}
// Returns Count() * index dimension (`d`) * sizeof(float), i.e. the byte size
// of that many float values.
size_t FaissExecutionEngine::PhysicalSize() const {
    const auto element_count = static_cast<size_t>(Count() * pIndex_->d);
    return element_count * sizeof(float);
}
......
......@@ -38,6 +38,8 @@ public:
size_t Size() const override;
size_t Dimension() const override;
size_t PhysicalSize() const override;
Status Serialize() override;
......
......@@ -15,7 +15,7 @@ typedef long IDNumber;
typedef IDNumber* IDNumberPtr;
typedef std::vector<IDNumber> IDNumbers;
typedef std::vector<IDNumber> QueryResult;
typedef std::vector<std::pair<IDNumber, double>> QueryResult;
typedef std::vector<QueryResult> QueryResults;
......
......@@ -31,8 +31,8 @@ public:
using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
const Id2IndexMap& GetIndexMap() const { return map_index_files_; }
using Score2IdMap = std::map<float, int64_t>;
using ResultSet = std::vector<Score2IdMap>;
using Id2ScoreMap = std::vector<std::pair<int64_t, double>>;
using ResultSet = std::vector<Id2ScoreMap>;
const ResultSet& GetResult() const { return result_; }
ResultSet& GetResult() { return result_; }
......
......@@ -19,12 +19,29 @@ void ClusterResult(const std::vector<long> &output_ids,
SearchContext::ResultSet &result_set) {
result_set.clear();
for (auto i = 0; i < nq; i++) {
SearchContext::Score2IdMap score2id;
SearchContext::Id2ScoreMap id_score;
for (auto k = 0; k < topk; k++) {
uint64_t index = i * nq + k;
score2id.insert(std::make_pair(output_distence[index], output_ids[index]));
id_score.push_back(std::make_pair(output_ids[index], output_distence[index]));
}
result_set.emplace_back(score2id);
result_set.emplace_back(id_score);
}
}
// Merges the entries of score_src into score_target and trims score_target to
// at most topk entries.
//
// Each source pair is inserted in front of the first target entry whose
// `second` value is strictly smaller than the source's; when no such entry
// exists the pair is appended, so it is not lost if score_target still holds
// fewer than topk items.
// NOTE(review): the result is ordered by descending `second` only if
// score_target was already ordered that way on entry - confirm with the caller.
//
// score_src    : pairs to merge in (read only here).
// score_target : destination list, updated in place.
// topk         : maximum number of entries left in score_target on return.
void MergeResult(SearchContext::Id2ScoreMap &score_src,
                 SearchContext::Id2ScoreMap &score_target,
                 uint64_t topk) {
    for (const auto& pair_src : score_src) {
        // Find the insertion point first, then insert exactly once. Inserting
        // while still walking the vector invalidates the iterator and keeps
        // re-matching the element that was just shifted right.
        auto insert_pos = score_target.begin();
        while (insert_pos != score_target.end() && !(pair_src.second > insert_pos->second)) {
            ++insert_pos;
        }
        score_target.insert(insert_pos, pair_src);
    }

    //remove unused items
    while (score_target.size() > topk) {
        score_target.pop_back();
    }
}
......@@ -42,18 +59,39 @@ void TopkResult(SearchContext::ResultSet &result_src,
}
for (size_t i = 0; i < result_src.size(); i++) {
SearchContext::Score2IdMap &score2id_src = result_src[i];
SearchContext::Score2IdMap &score2id_target = result_target[i];
for (auto iter = score2id_src.begin(); iter != score2id_src.end(); ++iter) {
score2id_target.insert(std::make_pair(iter->first, iter->second));
SearchContext::Id2ScoreMap &score_src = result_src[i];
SearchContext::Id2ScoreMap &score_target = result_target[i];
MergeResult(score_src, score_target, topk);
}
}
void CalcScore(uint64_t vector_count,
const float *vectors_data,
uint64_t dimension,
const SearchContext::ResultSet &result_src,
SearchContext::ResultSet &result_target) {
result_target.clear();
if(result_src.empty()){
return;
}
int vec_index = 0;
for(auto& result : result_src) {
const float * vec_data = vectors_data + vec_index*dimension;
double vec_len = 0;
for(uint64_t i = 0; i < dimension; i++) {
vec_len += vec_data[i]*vec_data[i];
}
vec_index++;
//remove unused items
while (score2id_target.size() > topk) {
score2id_target.erase(score2id_target.rbegin()->first);
SearchContext::Id2ScoreMap score_array;
for(auto& pair : result) {
score_array.push_back(std::make_pair(pair.first, (1 - pair.second/vec_len)*100.0));
}
result_target.emplace_back(score_array);
}
}
}
......@@ -78,10 +116,12 @@ bool SearchTask::DoSearch() {
std::vector<long> output_ids;
std::vector<float> output_distence;
for(auto& context : search_contexts_) {
//step 1: allocate memory
auto inner_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
output_ids.resize(inner_k*context->nq());
output_distence.resize(inner_k*context->nq());
//step 2: search
try {
index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(),
output_ids.data());
......@@ -93,11 +133,21 @@ bool SearchTask::DoSearch() {
rc.Record("do search");
//step 3: cluster result
SearchContext::ResultSet result_set;
ClusterResult(output_ids, output_distence, context->nq(), inner_k, result_set);
rc.Record("cluster result");
//step 4: pick up topk result
TopkResult(result_set, inner_k, context->GetResult());
rc.Record("reduce topk");
//step 5: calculate score between 0 ~ 100
CalcScore(context->nq(), context->vectors(), index_engine_->Dimension(), context->GetResult(), result_set);
context->GetResult().swap(result_set);
rc.Record("calculate score");
//step 6: notify to send result to client
context->IndexSearchDone(index_id_);
}
......
......@@ -400,9 +400,10 @@ ServerError SearchVectorTask::OnExecute() {
const auto& record = record_array_[i];
thrift::TopKQueryResult thrift_topk_result;
for(auto id : result) {
for(auto& pair : result) {
thrift::QueryResult thrift_result;
thrift_result.__set_id(id);
thrift_result.__set_id(pair.first);
thrift_result.__set_score(pair.second);
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
}
......
......@@ -164,11 +164,11 @@ TEST_F(DBTest, DB_TEST) {
ASSERT_STATS(stat);
for (auto k=0; k<qb; ++k) {
ASSERT_EQ(results[k][0], target_ids[k]);
ASSERT_EQ(results[k][0].first, target_ids[k]);
ss.str("");
ss << "Result [" << k << "]:";
for (auto result : results[k]) {
ss << result << " ";
ss << result.first << " ";
}
/* LOG(DEBUG) << ss.str(); */
}
......
......@@ -87,11 +87,11 @@ TEST_F(DBTest, Metric_Tes) {
ASSERT_STATS(stat);
for (auto k=0; k<qb; ++k) {
ASSERT_EQ(results[k][0], target_ids[k]);
ASSERT_EQ(results[k][0].first, target_ids[k]);
ss.str("");
ss << "Result [" << k << "]:";
for (auto result : results[k]) {
ss << result << " ";
ss << result.first << " ";
}
/* LOG(DEBUG) << ss.str(); */
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册