diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index e4c6ce04f69d12cbb63eca1dc5b38100c488aaba..731ff306302c3befa5d3ae8c70579dab8a69d14e 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -12,7 +12,8 @@ Please mark all change in change log and use the ticket from JIRA. - MS-57 - Implement index load/search pipeline - MS-56 - Add version information when server is started -- Ms-64 - Different table can have different index type +- MS-64 - Different table can have different index type +- MS-52 - Return search score ## Task diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index dcc397345e77843396eb0d45daf2d829d8fa49d1..6211c688fba1d51178e7bd5cb485a26b4c7d148a 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -240,7 +240,7 @@ Status DBImpl::QuerySync(const std::string& table_id, size_t k, size_t nq, int inner_k = dis.size() < k ? dis.size() : k; for (int i = 0; i < inner_k; ++i) { - res.emplace_back(nns[output_ids[i]]); // mapping + res.emplace_back(std::make_pair(nns[output_ids[i]], output_distence[i])); // mapping } results.push_back(res); // append to result list res.clear(); @@ -267,6 +267,8 @@ Status DBImpl::QuerySync(const std::string& table_id, size_t k, size_t nq, Status DBImpl::QueryAsync(const std::string& table_id, size_t k, size_t nq, const float* vectors, const meta::DatesT& dates, QueryResults& results) { + + //step 1: get files to search meta::DatePartionedTableFilesSchema files; auto status = pMeta_->FilesToSearch(table_id, dates, files); if (!status.ok()) { return status; } @@ -282,18 +284,15 @@ Status DBImpl::QueryAsync(const std::string& table_id, size_t k, size_t nq, } } + //step 2: put search task to scheduler SearchScheduler& scheduler = SearchScheduler::GetInstance(); scheduler.ScheduleSearchTask(context); context->WaitResult(); + + //step 3: construct results auto& context_result = context->GetResult(); - for(auto& topk_result : context_result) { - QueryResult ids; - for(auto& pair : topk_result) { - ids.push_back(pair.second); - } - results.emplace_back(ids); - } + results.swap(context_result); return Status::OK(); } diff --git a/cpp/src/db/ExecutionEngine.h b/cpp/src/db/ExecutionEngine.h index fe1acd913d80b6a5b1e610f94facfd8de1e227ba..ad4355786f2322968e5620c3f7d036b212598de3 100644 --- a/cpp/src/db/ExecutionEngine.h +++ b/cpp/src/db/ExecutionEngine.h @@ -32,6 +32,8 @@ public: virtual size_t Size() const = 0; + virtual size_t Dimension() const = 0; + virtual size_t PhysicalSize() const = 0; virtual Status Serialize() = 0; diff --git a/cpp/src/db/FaissExecutionEngine.cpp b/cpp/src/db/FaissExecutionEngine.cpp index a338ddf5cbb910f0f43febe737f59fc130cec582..b25a3150edd866b2ca3d1627d737b34cae047d77 100644 --- a/cpp/src/db/FaissExecutionEngine.cpp +++ b/cpp/src/db/FaissExecutionEngine.cpp @@ -54,6 +54,10 @@ size_t FaissExecutionEngine::Size() const { return (size_t)(Count() * pIndex_->d)*sizeof(float); } +size_t FaissExecutionEngine::Dimension() const { + return pIndex_->d; +} + size_t FaissExecutionEngine::PhysicalSize() const { return (size_t)(Count() * pIndex_->d)*sizeof(float); } diff --git a/cpp/src/db/FaissExecutionEngine.h b/cpp/src/db/FaissExecutionEngine.h index d1981502b9e25a30ed6849b4f6c1f637e918fd9e..e41fe064569d044c0cb270c07cf20f1a17f24870 100644 --- a/cpp/src/db/FaissExecutionEngine.h +++ b/cpp/src/db/FaissExecutionEngine.h @@ -38,6 +38,8 @@ public: size_t Size() const override; + size_t Dimension() const override; + size_t PhysicalSize() const override; Status Serialize() override; diff --git a/cpp/src/db/Types.h b/cpp/src/db/Types.h index f9a432fd94e3ee1decb50b7999b2ca0cec61f5bb..73ecc81fa8da14cb3b1c9e3846de2aaba90c844c 100644 --- a/cpp/src/db/Types.h +++ b/cpp/src/db/Types.h @@ -15,7 +15,7 @@ typedef long IDNumber; typedef IDNumber* IDNumberPtr; typedef std::vector IDNumbers; -typedef std::vector QueryResult; +typedef std::vector> QueryResult; typedef std::vector QueryResults; diff --git a/cpp/src/db/scheduler/SearchContext.h b/cpp/src/db/scheduler/SearchContext.h index ae7327fd6829eee012b123296521044f0957f94b..b212ea34d99d15fed7ea5d0dd8e2134c61372d26 100644 --- a/cpp/src/db/scheduler/SearchContext.h +++ b/cpp/src/db/scheduler/SearchContext.h @@ -31,8 +31,8 @@ public: using Id2IndexMap = std::unordered_map; const Id2IndexMap& GetIndexMap() const { return map_index_files_; } - using Score2IdMap = std::map; - using ResultSet = std::vector; + using Id2ScoreMap = std::vector>; + using ResultSet = std::vector; const ResultSet& GetResult() const { return result_; } ResultSet& GetResult() { return result_; } diff --git a/cpp/src/db/scheduler/SearchTaskQueue.cpp b/cpp/src/db/scheduler/SearchTaskQueue.cpp index 86478477d1f2008c0bb04ac2934f1f55cbe57a34..38db5fd7a7cf675e9bd4496f2e60c14c05f88748 100644 --- a/cpp/src/db/scheduler/SearchTaskQueue.cpp +++ b/cpp/src/db/scheduler/SearchTaskQueue.cpp @@ -19,12 +19,29 @@ void ClusterResult(const std::vector &output_ids, SearchContext::ResultSet &result_set) { result_set.clear(); for (auto i = 0; i < nq; i++) { - SearchContext::Score2IdMap score2id; + SearchContext::Id2ScoreMap id_score; for (auto k = 0; k < topk; k++) { uint64_t index = i * nq + k; - score2id.insert(std::make_pair(output_distence[index], output_ids[index])); + id_score.push_back(std::make_pair(output_ids[index], output_distence[index])); } - result_set.emplace_back(score2id); + result_set.emplace_back(id_score); + } +} + +void MergeResult(SearchContext::Id2ScoreMap &score_src, + SearchContext::Id2ScoreMap &score_target, + uint64_t topk) { + for (auto& pair_src : score_src) { + for (auto iter = score_target.begin(); iter != score_target.end(); ++iter) { + if(pair_src.second > iter->second) { + score_target.insert(iter, pair_src); + } + } + } + + //remove unused items + while (score_target.size() > topk) { + score_target.pop_back(); } } @@ -42,18 +59,39 @@ void TopkResult(SearchContext::ResultSet &result_src, } for (size_t i = 0; i < result_src.size(); i++) { - SearchContext::Score2IdMap &score2id_src = result_src[i]; - SearchContext::Score2IdMap &score2id_target = result_target[i]; - for (auto iter = score2id_src.begin(); iter != score2id_src.end(); ++iter) { - score2id_target.insert(std::make_pair(iter->first, iter->second)); + SearchContext::Id2ScoreMap &score_src = result_src[i]; + SearchContext::Id2ScoreMap &score_target = result_target[i]; + MergeResult(score_src, score_target, topk); + } +} + +void CalcScore(uint64_t vector_count, + const float *vectors_data, + uint64_t dimension, + const SearchContext::ResultSet &result_src, + SearchContext::ResultSet &result_target) { + result_target.clear(); + if(result_src.empty()){ + return; + } + + int vec_index = 0; + for(auto& result : result_src) { + const float * vec_data = vectors_data + vec_index*dimension; + double vec_len = 0; + for(uint64_t i = 0; i < dimension; i++) { + vec_len += vec_data[i]*vec_data[i]; } + vec_index++; - //remove unused items - while (score2id_target.size() > topk) { - score2id_target.erase(score2id_target.rbegin()->first); + SearchContext::Id2ScoreMap score_array; + for(auto& pair : result) { + score_array.push_back(std::make_pair(pair.first, (1 - pair.second/vec_len)*100.0)); } + result_target.emplace_back(score_array); } } + } @@ -78,10 +116,12 @@ bool SearchTask::DoSearch() { std::vector output_ids; std::vector output_distence; for(auto& context : search_contexts_) { + //step 1: allocate memory auto inner_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk(); output_ids.resize(inner_k*context->nq()); output_distence.resize(inner_k*context->nq()); + //step 2: search try { index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(), output_ids.data()); @@ -93,11 +133,21 @@ bool SearchTask::DoSearch() { rc.Record("do search"); + //step 3: cluster result SearchContext::ResultSet result_set; ClusterResult(output_ids, output_distence, context->nq(), inner_k, result_set); rc.Record("cluster result"); + + //step 4: pick up topk result TopkResult(result_set, inner_k, context->GetResult()); rc.Record("reduce topk"); + + //step 5: calculate score between 0 ~ 100 + CalcScore(context->nq(), context->vectors(), index_engine_->Dimension(), context->GetResult(), result_set); + context->GetResult().swap(result_set); + rc.Record("calculate score"); + + //step 6: notify to send result to client context->IndexSearchDone(index_id_); } diff --git a/cpp/src/server/MegasearchTask.cpp b/cpp/src/server/MegasearchTask.cpp index 2980deb6fa100b6bf3548b65ea367b45226ec237..7c78b100468a3269400df7292af175180ac97202 100644 --- a/cpp/src/server/MegasearchTask.cpp +++ b/cpp/src/server/MegasearchTask.cpp @@ -400,9 +400,10 @@ ServerError SearchVectorTask::OnExecute() { const auto& record = record_array_[i]; thrift::TopKQueryResult thrift_topk_result; - for(auto id : result) { + for(auto& pair : result) { thrift::QueryResult thrift_result; - thrift_result.__set_id(id); + thrift_result.__set_id(pair.first); + thrift_result.__set_score(pair.second); thrift_topk_result.query_result_arrays.emplace_back(thrift_result); } diff --git a/cpp/unittest/db/db_tests.cpp b/cpp/unittest/db/db_tests.cpp index 459acf9ab7a2b5be18421870b48e0e614182cbf2..c903a7b9573933c2606c86e1bc1db0c12067e041 100644 --- a/cpp/unittest/db/db_tests.cpp +++ b/cpp/unittest/db/db_tests.cpp @@ -164,11 +164,11 @@ TEST_F(DBTest, DB_TEST) { ASSERT_STATS(stat); for (auto k=0; k