diff --git a/CHANGELOG.md b/CHANGELOG.md index bd6edf52f6b2b4113b7411e50be2ce8b439c99e6..419a12339090b55a36510758686bb74ac530aeaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ Please mark all change in change log and use the ticket from JIRA. - \#149 - Improve large query optimizer pass - \#156 - Not return error when search_resources and index_build_device set cpu - \#159 - Change the configuration name from 'use_gpu_threshold' to 'gpu_search_threshold' +- \#168 - Improve result reduce ## Task diff --git a/core/src/db/DB.h b/core/src/db/DB.h index a790fadb502f3f1d93d879873daa8a5adb228c44..07fe30babd911a64b596ff321bdddde814f745c2 100644 --- a/core/src/db/DB.h +++ b/core/src/db/DB.h @@ -67,15 +67,16 @@ class DB { virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, - QueryResults& results) = 0; + ResultIds& result_ids, ResultDistances& result_distances) = 0; virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, - const meta::DatesT& dates, QueryResults& results) = 0; + const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) = 0; virtual Status Query(const std::string& table_id, const std::vector& file_ids, uint64_t k, uint64_t nq, - uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) = 0; + uint64_t nprobe, const float* vectors, const meta::DatesT& dates, ResultIds& result_ids, + ResultDistances& result_distances) = 0; virtual Status Size(uint64_t& result) = 0; diff --git a/core/src/db/DBImpl.cpp b/core/src/db/DBImpl.cpp index 6995de3d1407e76c9a73f30a096c45c278cbff23..fc31846bd391f26fe423649a70ad0d0fe22590c8 100644 --- a/core/src/db/DBImpl.cpp +++ b/core/src/db/DBImpl.cpp @@ -336,20 +336,20 @@ DBImpl::DropIndex(const std::string& table_id) { Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, - QueryResults& results) { + ResultIds& result_ids, ResultDistances& result_distances) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } meta::DatesT dates = {utils::GetDate()}; - Status result = Query(table_id, k, nq, nprobe, vectors, dates, results); + Status result = Query(table_id, k, nq, nprobe, vectors, dates, result_ids, result_distances); return result; } Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, - const meta::DatesT& dates, QueryResults& results) { + const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } @@ -372,14 +372,15 @@ DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t npr } cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query - status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, results); + status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, result_ids, result_distances); cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query return status; } Status DBImpl::Query(const std::string& table_id, const std::vector& file_ids, uint64_t k, uint64_t nq, - uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) { + uint64_t nprobe, const float* vectors, const meta::DatesT& dates, ResultIds& result_ids, + ResultDistances& result_distances) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } @@ -413,7 +414,7 @@ DBImpl::Query(const std::string& table_id, const std::vector& file_ } cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query - status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, results); + status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, result_ids, result_distances); cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query return status; } @@ -432,7 +433,7 @@ DBImpl::Size(uint64_t& result) { /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// Status DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, - uint64_t nprobe, const float* vectors, QueryResults& results) { + uint64_t nprobe, const float* vectors, ResultIds& result_ids, ResultDistances& result_distances) { server::CollectQueryMetrics metrics(nq); TimeRecorder rc(""); @@ -453,7 +454,8 @@ DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& fi } // step 3: construct results - results = job->GetResult(); + result_ids = job->GetResultIds(); + result_distances = job->GetResultDistances(); rc.ElapseFromBegin("Engine query totally cost"); return Status::OK(); diff --git a/core/src/db/DBImpl.h b/core/src/db/DBImpl.h index e1e030cc32651c874efecf24719e20ecfac423f6..ad9c574bb1b583b02da4642786ac9a13ff0ebdf7 100644 --- a/core/src/db/DBImpl.h +++ b/core/src/db/DBImpl.h @@ -91,15 +91,16 @@ class DBImpl : public DB { Status Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, - QueryResults& results) override; + ResultIds& result_ids, ResultDistances& result_distances) override; Status Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, - const meta::DatesT& dates, QueryResults& results) override; + const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) override; Status Query(const std::string& table_id, const std::vector& file_ids, uint64_t k, uint64_t nq, - uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) override; + uint64_t nprobe, const float* vectors, const meta::DatesT& dates, ResultIds& result_ids, + ResultDistances& result_distances) override; Status Size(uint64_t& result) override; @@ -107,7 +108,7 @@ class DBImpl : public DB { private: Status QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, - uint64_t nprobe, const float* vectors, QueryResults& results); + uint64_t nprobe, const float* vectors, ResultIds& result_ids, ResultDistances& result_distances); void BackgroundTimerTask(); diff --git a/core/src/db/Types.h b/core/src/db/Types.h index 94528a9a8afb0023ec09ad3d12603c1a3f0a6746..cc2eab03831ea7d3bb052aef07faeba02afb99ea 100644 --- a/core/src/db/Types.h +++ b/core/src/db/Types.h @@ -19,6 +19,7 @@ #include "db/engine/ExecutionEngine.h" +#include #include #include #include @@ -26,12 +27,13 @@ namespace milvus { namespace engine { -typedef int64_t IDNumber; +using IDNumber = faiss::Index::idx_t; + typedef IDNumber* IDNumberPtr; typedef std::vector IDNumbers; -typedef std::vector> QueryResult; -typedef std::vector QueryResults; +typedef std::vector ResultIds; +typedef std::vector ResultDistances; struct TableIndex { int32_t engine_type_ = (int)EngineType::FAISS_IDMAP; diff --git a/core/src/scheduler/job/SearchJob.cpp b/core/src/scheduler/job/SearchJob.cpp index 47c825c12242c53c803f1dcd5fdd92c13e8bc21f..ee41c1ae0605b1e2a32566f09f1986b0c5291287 100644 --- a/core/src/scheduler/job/SearchJob.cpp +++ b/core/src/scheduler/job/SearchJob.cpp @@ -53,9 +53,14 @@ SearchJob::SearchDone(size_t index_id) { SERVER_LOG_DEBUG << "SearchJob " << id() << " finish index file: " << index_id; } -ResultSet& -SearchJob::GetResult() { - return result_; +ResultIds& +SearchJob::GetResultIds() { + return result_ids_; +} + +ResultDistances& +SearchJob::GetResultDistances() { + return result_distances_; } Status& diff --git a/core/src/scheduler/job/SearchJob.h b/core/src/scheduler/job/SearchJob.h index 90fcf367733e4e4a58af924fbc48a5f3300d34c3..ff5ab34131e83c26019c2462d2d63bfade6f4100 100644 --- a/core/src/scheduler/job/SearchJob.h +++ b/core/src/scheduler/job/SearchJob.h @@ -29,6 +29,7 @@ #include #include "Job.h" +#include "db/Types.h" #include "db/meta/MetaTypes.h" namespace milvus { @@ -37,9 +38,9 @@ namespace scheduler { using engine::meta::TableFileSchemaPtr; using Id2IndexMap = std::unordered_map; -using IdDistPair = std::pair; -using Id2DistVec = std::vector; -using ResultSet = std::vector; + +using ResultIds = engine::ResultIds; +using ResultDistances = engine::ResultDistances; class SearchJob : public Job { public: @@ -55,8 +56,11 @@ class SearchJob : public Job { void SearchDone(size_t index_id); - ResultSet& - GetResult(); + ResultIds& + GetResultIds(); + + ResultDistances& + GetResultDistances(); Status& GetStatus(); @@ -104,7 +108,8 @@ class SearchJob : public Job { Id2IndexMap index_files_; // TODO: column-base better ? - ResultSet result_; + ResultIds result_ids_; + ResultDistances result_distances_; Status status_; std::mutex mutex_; diff --git a/core/src/scheduler/task/SearchTask.cpp b/core/src/scheduler/task/SearchTask.cpp index edeb41bdbe5131a2fa160d33a444284c6b8265bb..08bc6525aa968a1c374544dd0ab348b62eed0de7 100644 --- a/core/src/scheduler/task/SearchTask.cpp +++ b/core/src/scheduler/task/SearchTask.cpp @@ -222,7 +222,7 @@ XSearchTask::Execute() { { std::unique_lock lock(search_job->mutex()); XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2, - search_job->GetResult()); + search_job->GetResultIds(), search_job->GetResultDistances()); } span = rc.RecordSection(hdr + ", reduce topk"); @@ -243,71 +243,69 @@ XSearchTask::Execute() { } void -XSearchTask::MergeTopkToResultSet(const std::vector& input_ids, const std::vector& input_distance, - uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, - scheduler::ResultSet& result) { - if (result.empty()) { - result.resize(nq); +XSearchTask::MergeTopkToResultSet(const scheduler::ResultIds& src_ids, const scheduler::ResultDistances& src_distances, + size_t src_k, size_t nq, size_t topk, bool ascending, scheduler::ResultIds& tar_ids, + scheduler::ResultDistances& tar_distances) { + if (src_ids.empty()) { + return; } + size_t tar_k = tar_ids.size() / nq; + size_t buf_k = std::min(topk, src_k + tar_k); + + scheduler::ResultIds buf_ids(nq * buf_k, -1); + scheduler::ResultDistances buf_distances(nq * buf_k, 0.0); + for (uint64_t i = 0; i < nq; i++) { - scheduler::Id2DistVec result_buf; - auto& result_i = result[i]; - - if (result[i].empty()) { - result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0)); - uint64_t input_k_multi_i = topk * i; - for (auto k = 0; k < input_k; ++k) { - uint64_t idx = input_k_multi_i + k; - auto& result_buf_item = result_buf[k]; - result_buf_item.first = input_ids[idx]; - result_buf_item.second = input_distance[idx]; - } - } else { - size_t tar_size = result_i.size(); - uint64_t output_k = std::min(topk, input_k + tar_size); - result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0)); - size_t buf_k = 0, src_k = 0, tar_k = 0; - uint64_t src_idx; - uint64_t input_k_multi_i = topk * i; - while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { - src_idx = input_k_multi_i + src_k; - auto& result_buf_item = result_buf[buf_k]; - auto& result_item = result_i[tar_k]; - if ((ascending && input_distance[src_idx] < result_item.second) || - (!ascending && input_distance[src_idx] > result_item.second)) { - result_buf_item.first = input_ids[src_idx]; - result_buf_item.second = input_distance[src_idx]; - src_k++; - } else { - result_buf_item = result_item; - tar_k++; - } - buf_k++; + size_t buf_k_j = 0, src_k_j = 0, tar_k_j = 0; + size_t buf_idx, src_idx, tar_idx; + + size_t buf_k_multi_i = buf_k * i; + size_t src_k_multi_i = topk * i; + size_t tar_k_multi_i = tar_k * i; + + while (buf_k_j < buf_k && src_k_j < src_k && tar_k_j < tar_k) { + src_idx = src_k_multi_i + src_k_j; + tar_idx = tar_k_multi_i + tar_k_j; + buf_idx = buf_k_multi_i + buf_k_j; + + if ((ascending && src_distances[src_idx] < tar_distances[tar_idx]) || + (!ascending && src_distances[src_idx] > tar_distances[tar_idx])) { + buf_ids[buf_idx] = src_ids[src_idx]; + buf_distances[buf_idx] = src_distances[src_idx]; + src_k_j++; + } else { + buf_ids[buf_idx] = tar_ids[tar_idx]; + buf_distances[buf_idx] = tar_distances[tar_idx]; + tar_k_j++; } + buf_k_j++; + } - if (buf_k < output_k) { - if (src_k < input_k) { - while (buf_k < output_k && src_k < input_k) { - src_idx = input_k_multi_i + src_k; - auto& result_buf_item = result_buf[buf_k]; - result_buf_item.first = input_ids[src_idx]; - result_buf_item.second = input_distance[src_idx]; - src_k++; - buf_k++; - } - } else { - while (buf_k < output_k && tar_k < tar_size) { - result_buf[buf_k] = result_i[tar_k]; - tar_k++; - buf_k++; - } + if (buf_k_j < buf_k) { + if (src_k_j < src_k) { + while (buf_k_j < buf_k && src_k_j < src_k) { + buf_idx = buf_k_multi_i + buf_k_j; + src_idx = src_k_multi_i + src_k_j; + buf_ids[buf_idx] = src_ids[src_idx]; + buf_distances[buf_idx] = src_distances[src_idx]; + src_k_j++; + buf_k_j++; + } + } else { + while (buf_k_j < buf_k && tar_k_j < tar_k) { + buf_idx = buf_k_multi_i + buf_k_j; + tar_idx = tar_k_multi_i + tar_k_j; + buf_ids[buf_idx] = tar_ids[tar_idx]; + buf_distances[buf_idx] = tar_distances[tar_idx]; + tar_k_j++; + buf_k_j++; } } } - - result_i.swap(result_buf); } + tar_ids.swap(buf_ids); + tar_distances.swap(buf_distances); } // void diff --git a/core/src/scheduler/task/SearchTask.h b/core/src/scheduler/task/SearchTask.h index bbc8b5bd8fc04b59a8ed9eafff8dfe6c46255adb..bd5113734178e2bb589b4fd956a61d57771b95f0 100644 --- a/core/src/scheduler/task/SearchTask.h +++ b/core/src/scheduler/task/SearchTask.h @@ -39,8 +39,9 @@ class XSearchTask : public Task { public: static void - MergeTopkToResultSet(const std::vector& input_ids, const std::vector& input_distance, - uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); + MergeTopkToResultSet(const scheduler::ResultIds& src_ids, const scheduler::ResultDistances& src_distances, + size_t src_k, size_t nq, size_t topk, bool ascending, scheduler::ResultIds& tar_ids, + scheduler::ResultDistances& tar_distances); // static void // MergeTopkArray(std::vector& tar_ids, std::vector& tar_distance, uint64_t& tar_input_k, diff --git a/core/src/server/grpc_impl/GrpcRequestTask.cpp b/core/src/server/grpc_impl/GrpcRequestTask.cpp index be1fca0186185d3b58325864e8b9d40791422cda..77f262bda610d04d2ee4521e20fb78f27d96eda3 100644 --- a/core/src/server/grpc_impl/GrpcRequestTask.cpp +++ b/core/src/server/grpc_impl/GrpcRequestTask.cpp @@ -637,7 +637,8 @@ SearchTask::OnExecute() { rc.RecordSection("prepare vector data"); // step 6: search vectors - engine::QueryResults results; + engine::ResultIds result_ids; + engine::ResultDistances result_distances; auto record_count = (uint64_t)search_param_->query_record_array().size(); #ifdef MILVUS_ENABLE_PROFILING @@ -647,11 +648,11 @@ SearchTask::OnExecute() { #endif if (file_id_array_.empty()) { - status = - DBWrapper::DB()->Query(table_name_, (size_t)top_k, record_count, nprobe, vec_f.data(), dates, results); + status = DBWrapper::DB()->Query(table_name_, (size_t)top_k, record_count, nprobe, vec_f.data(), dates, + result_ids, result_distances); } else { status = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t)top_k, record_count, nprobe, - vec_f.data(), dates, results); + vec_f.data(), dates, result_ids, result_distances); } #ifdef MILVUS_ENABLE_PROFILING @@ -663,23 +664,20 @@ SearchTask::OnExecute() { return status; } - if (results.empty()) { + if (result_ids.empty()) { return Status::OK(); // empty table } - if (results.size() != record_count) { - std::string msg = "Search " + std::to_string(record_count) + " vectors but only return " + - std::to_string(results.size()) + " results"; - return Status(SERVER_ILLEGAL_SEARCH_RESULT, msg); - } + size_t result_k = result_ids.size() / record_count; // step 7: construct result array - for (auto& result : results) { + for (size_t i = 0; i < record_count; i++) { ::milvus::grpc::TopKQueryResult* topk_query_result = topk_result_list->add_topk_query_result(); - for (auto& pair : result) { + for (size_t j = 0; j < result_k; j++) { ::milvus::grpc::QueryResult* grpc_result = topk_query_result->add_query_result_arrays(); - grpc_result->set_id(pair.first); - grpc_result->set_distance(pair.second); + size_t idx = i * result_k + j; + grpc_result->set_id(result_ids[idx]); + grpc_result->set_distance(result_distances[idx]); } } diff --git a/core/unittest/db/test_db.cpp b/core/unittest/db/test_db.cpp index b869d173889cf91dd648a7ccc4371af46193ad20..f9e8da9c0fc72f99b03ad367c9eadf6a973ed349 100644 --- a/core/unittest/db/test_db.cpp +++ b/core/unittest/db/test_db.cpp @@ -175,7 +175,8 @@ TEST_F(DBTest, DB_TEST) { BuildVectors(qb, qxb); std::thread search([&]() { - milvus::engine::QueryResults results; + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; int k = 10; std::this_thread::sleep_for(std::chrono::seconds(2)); @@ -190,17 +191,17 @@ TEST_F(DBTest, DB_TEST) { prev_count = count; START_TIMER; - stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results); + stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), result_ids, result_distances); ss << "Search " << j << " With Size " << count / milvus::engine::M << " M"; STOP_TIMER(ss.str()); ASSERT_TRUE(stat.ok()); - for (auto k = 0; k < qb; ++k) { - ASSERT_EQ(results[k][0].first, target_ids[k]); + for (auto i = 0; i < qb; ++i) { + ASSERT_EQ(result_ids[i*k], target_ids[i]); ss.str(""); - ss << "Result [" << k << "]:"; - for (auto result : results[k]) { - ss << result.first << " "; + ss << "Result [" << i << "]:"; + for (auto t = 0; t < k; t++) { + ss << result_ids[i * k + t] << " "; } /* LOG(DEBUG) << ss.str(); */ } @@ -284,16 +285,18 @@ TEST_F(DBTest, SEARCH_TEST) { db_->CreateIndex(TABLE_NAME, index); // wait until build index finish { - milvus::engine::QueryResults results; - stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; + stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), result_ids, result_distances); ASSERT_TRUE(stat.ok()); } {//search by specify index file milvus::engine::meta::DatesT dates; std::vector file_ids = {"1", "2", "3", "4", "5", "6"}; - milvus::engine::QueryResults results; - stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results); + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; + stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, result_ids, result_distances); ASSERT_TRUE(stat.ok()); } @@ -303,22 +306,25 @@ TEST_F(DBTest, SEARCH_TEST) { db_->CreateIndex(TABLE_NAME, index); // wait until build index finish { - milvus::engine::QueryResults results; - stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; + stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), result_ids, result_distances); ASSERT_TRUE(stat.ok()); } { - milvus::engine::QueryResults large_nq_results; - stat = db_->Query(TABLE_NAME, k, 200, 10, xq.data(), large_nq_results); + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; + stat = db_->Query(TABLE_NAME, k, 200, 10, xq.data(), result_ids, result_distances); ASSERT_TRUE(stat.ok()); } {//search by specify index file milvus::engine::meta::DatesT dates; std::vector file_ids = {"1", "2", "3", "4", "5", "6"}; - milvus::engine::QueryResults results; - stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results); + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; + stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, result_ids, result_distances); ASSERT_TRUE(stat.ok()); } @@ -391,11 +397,12 @@ TEST_F(DBTest, SHUTDOWN_TEST) { ASSERT_FALSE(stat.ok()); milvus::engine::meta::DatesT dates; - milvus::engine::QueryResults results; - stat = db_->Query(table_info.table_id_, 1, 1, 1, nullptr, dates, results); + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; + stat = db_->Query(table_info.table_id_, 1, 1, 1, nullptr, dates, result_ids, result_distances); ASSERT_FALSE(stat.ok()); std::vector file_ids; - stat = db_->Query(table_info.table_id_, file_ids, 1, 1, 1, nullptr, dates, results); + stat = db_->Query(table_info.table_id_, file_ids, 1, 1, 1, nullptr, dates, result_ids, result_distances); ASSERT_FALSE(stat.ok()); stat = db_->DeleteTable(table_info.table_id_, dates); diff --git a/core/unittest/db/test_db_mysql.cpp b/core/unittest/db/test_db_mysql.cpp index ae1da8012aaedbc6c1fec6710994516144b9079e..30a616e66274a83200b8ff9f6269a5492c77d794 100644 --- a/core/unittest/db/test_db_mysql.cpp +++ b/core/unittest/db/test_db_mysql.cpp @@ -81,7 +81,8 @@ TEST_F(MySqlDBTest, DB_TEST) { ASSERT_EQ(target_ids.size(), qb); std::thread search([&]() { - milvus::engine::QueryResults results; + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; int k = 10; std::this_thread::sleep_for(std::chrono::seconds(5)); @@ -96,25 +97,25 @@ TEST_F(MySqlDBTest, DB_TEST) { prev_count = count; START_TIMER; - stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results); + stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), result_ids, result_distances); ss << "Search " << j << " With Size " << count / milvus::engine::M << " M"; STOP_TIMER(ss.str()); ASSERT_TRUE(stat.ok()); - for (auto k = 0; k < qb; ++k) { + for (auto i = 0; i < qb; ++i) { // std::cout << results[k][0].first << " " << target_ids[k] << std::endl; // ASSERT_EQ(results[k][0].first, target_ids[k]); bool exists = false; - for (auto &result : results[k]) { - if (result.first == target_ids[k]) { + for (auto t = 0; t < k; t++) { + if (result_ids[i * k + t] == target_ids[i]) { exists = true; } } ASSERT_TRUE(exists); ss.str(""); - ss << "Result [" << k << "]:"; - for (auto result : results[k]) { - ss << result.first << " "; + ss << "Result [" << i << "]:"; + for (auto t = 0; t < k; t++) { + ss << result_ids[i * k + t] << " "; } /* LOG(DEBUG) << ss.str(); */ } @@ -188,8 +189,9 @@ TEST_F(MySqlDBTest, SEARCH_TEST) { sleep(2); // wait until build index finish - milvus::engine::QueryResults results; - stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; + stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), result_ids, result_distances); ASSERT_TRUE(stat.ok()); } diff --git a/core/unittest/db/test_mem.cpp b/core/unittest/db/test_mem.cpp index e05811ff9e050d8e11a946f63d67dc467de13da5..939e61246cf1140ab7fc978583c384a3b2aa5e3a 100644 --- a/core/unittest/db/test_mem.cpp +++ b/core/unittest/db/test_mem.cpp @@ -259,10 +259,11 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) { int topk = 10, nprobe = 10; for (auto& pair : search_vectors) { auto& search = pair.second; - milvus::engine::QueryResults results; - stat = db_->Query(GetTableName(), topk, 1, nprobe, search.data(), results); - ASSERT_EQ(results[0][0].first, pair.first); - ASSERT_LT(results[0][0].second, 1e-4); + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; + stat = db_->Query(GetTableName(), topk, 1, nprobe, search.data(), result_ids, result_distances); + ASSERT_EQ(result_ids[0], pair.first); + ASSERT_LT(result_distances[0], 1e-4); } } @@ -314,7 +315,8 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) { BuildVectors(qb, qxb); std::thread search([&]() { - milvus::engine::QueryResults results; + milvus::engine::ResultIds result_ids; + milvus::engine::ResultDistances result_distances; int k = 10; std::this_thread::sleep_for(std::chrono::seconds(2)); @@ -329,17 +331,17 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) { prev_count = count; START_TIMER; - stat = db_->Query(GetTableName(), k, qb, 10, qxb.data(), results); + stat = db_->Query(GetTableName(), k, qb, 10, qxb.data(), result_ids, result_distances); ss << "Search " << j << " With Size " << count / milvus::engine::M << " M"; STOP_TIMER(ss.str()); ASSERT_TRUE(stat.ok()); - for (auto k = 0; k < qb; ++k) { - ASSERT_EQ(results[k][0].first, target_ids[k]); + for (auto i = 0; i < qb; ++i) { + ASSERT_EQ(result_ids[i * k], target_ids[i]); ss.str(""); - ss << "Result [" << k << "]:"; - for (auto result : results[k]) { - ss << result.first << " "; + ss << "Result [" << i << "]:"; + for (auto t = 0; t < k; t++) { + ss << result_ids[i * k + t] << " "; } /* LOG(DEBUG) << ss.str(); */ } diff --git a/core/unittest/db/test_search.cpp b/core/unittest/db/test_search.cpp index b8cf08b3e2fc8d982027022619a40428c2d001a0..1d1d9a677add4bca762336727effa0c419516248 100644 --- a/core/unittest/db/test_search.cpp +++ b/core/unittest/db/test_search.cpp @@ -19,6 +19,7 @@ #include #include +#include "scheduler/job/SearchJob.h" #include "scheduler/task/SearchTask.h" #include "utils/TimeRecorder.h" #include "utils/ThreadPool.h" @@ -28,74 +29,80 @@ namespace { namespace ms = milvus::scheduler; void -BuildResult(std::vector& output_ids, - std::vector& output_distance, - uint64_t input_k, - uint64_t topk, - uint64_t nq, +BuildResult(ms::ResultIds& output_ids, + ms::ResultDistances & output_distances, + size_t input_k, + size_t topk, + size_t nq, bool ascending) { output_ids.clear(); output_ids.resize(nq * topk); - output_distance.clear(); - output_distance.resize(nq * topk); + output_distances.clear(); + output_distances.resize(nq * topk); - for (uint64_t i = 0; i < nq; i++) { + for (size_t i = 0; i < nq; i++) { //insert valid items - for (uint64_t j = 0; j < input_k; j++) { + for (size_t j = 0; j < input_k; j++) { output_ids[i * topk + j] = (int64_t)(drand48() * 100000); - output_distance[i * topk + j] = ascending ? (j + drand48()) : ((input_k - j) + drand48()); + output_distances[i * topk + j] = ascending ? (j + drand48()) : ((input_k - j) + drand48()); } //insert invalid items - for (uint64_t j = input_k; j < topk; j++) { + for (size_t j = input_k; j < topk; j++) { output_ids[i * topk + j] = -1; - output_distance[i * topk + j] = -1.0; + output_distances[i * topk + j] = -1.0; } } } void -CopyResult(std::vector& output_ids, - std::vector& output_distance, - uint64_t output_topk, - std::vector& input_ids, - std::vector& input_distance, - uint64_t input_topk, - uint64_t nq) { +CopyResult(ms::ResultIds& output_ids, + ms::ResultDistances& output_distances, + size_t output_topk, + ms::ResultIds& input_ids, + ms::ResultDistances& input_distances, + size_t input_topk, + size_t nq) { ASSERT_TRUE(input_ids.size() >= nq * input_topk); - ASSERT_TRUE(input_distance.size() >= nq * input_topk); + ASSERT_TRUE(input_distances.size() >= nq * input_topk); ASSERT_TRUE(output_topk <= input_topk); output_ids.clear(); output_ids.resize(nq * output_topk); - output_distance.clear(); - output_distance.resize(nq * output_topk); + output_distances.clear(); + output_distances.resize(nq * output_topk); - for (uint64_t i = 0; i < nq; i++) { - for (uint64_t j = 0; j < output_topk; j++) { + for (size_t i = 0; i < nq; i++) { + for (size_t j = 0; j < output_topk; j++) { output_ids[i * output_topk + j] = input_ids[i * input_topk + j]; - output_distance[i * output_topk + j] = input_distance[i * input_topk + j]; + output_distances[i * output_topk + j] = input_distances[i * input_topk + j]; } } } void -CheckTopkResult(const std::vector& input_ids_1, - const std::vector& input_distance_1, - const std::vector& input_ids_2, - const std::vector& input_distance_2, - uint64_t topk, - uint64_t nq, +CheckTopkResult(const ms::ResultIds& input_ids_1, + const ms::ResultDistances& input_distances_1, + size_t input_k_1, + const ms::ResultIds& input_ids_2, + const ms::ResultDistances& input_distances_2, + size_t input_k_2, + size_t topk, + size_t nq, bool ascending, - const milvus::scheduler::ResultSet& result) { - ASSERT_EQ(result.size(), nq); - ASSERT_EQ(input_ids_1.size(), input_distance_1.size()); - ASSERT_EQ(input_ids_2.size(), input_distance_2.size()); + const ms::ResultIds& result_ids, + const ms::ResultDistances& result_distances) { + ASSERT_EQ(result_ids.size(), result_distances.size()); + ASSERT_EQ(input_ids_1.size(), input_distances_1.size()); + ASSERT_EQ(input_ids_2.size(), input_distances_2.size()); - for (int64_t i = 0; i < nq; i++) { + size_t result_k = result_distances.size() / nq; + ASSERT_EQ(result_k, std::min(topk, input_k_1 + input_k_2)); + + for (size_t i = 0; i < nq; i++) { std::vector - src_vec(input_distance_1.begin() + i * topk, input_distance_1.begin() + (i + 1) * topk); + src_vec(input_distances_1.begin() + i * topk, input_distances_1.begin() + (i + 1) * topk); src_vec.insert(src_vec.end(), - input_distance_2.begin() + i * topk, - input_distance_2.begin() + (i + 1) * topk); + input_distances_2.begin() + i * topk, + input_distances_2.begin() + (i + 1) * topk); if (ascending) { std::sort(src_vec.begin(), src_vec.end()); } else { @@ -111,15 +118,16 @@ CheckTopkResult(const std::vector& input_ids_1, ++iter; } - uint64_t n = std::min(topk, result[i].size()); - for (uint64_t j = 0; j < n; j++) { - if (result[i][j].first < 0) { + size_t n = std::min(topk, result_ids.size() / nq); + for (size_t j = 0; j < n; j++) { + size_t idx = i * n + j; + if (result_ids[idx] < 0) { continue; } - if (src_vec[j] != result[i][j].second) { - std::cout << src_vec[j] << " " << result[i][j].second << std::endl; + if (src_vec[j] != result_distances[idx]) { + std::cout << src_vec[j] << " " << result_distances[idx] << std::endl; } - ASSERT_TRUE(src_vec[j] == result[i][j].second); + ASSERT_TRUE(src_vec[j] == result_distances[idx]); } } } @@ -127,20 +135,21 @@ CheckTopkResult(const std::vector& input_ids_1, } // namespace void -MergeTopkToResultSetTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { - std::vector ids1, ids2; - std::vector dist1, dist2; - ms::ResultSet result; +MergeTopkToResultSetTest(size_t topk_1, size_t topk_2, size_t nq, size_t topk, bool ascending) { + ms::ResultIds ids1, ids2; + ms::ResultDistances dist1, dist2; + ms::ResultIds result_ids; + ms::ResultDistances result_distances; BuildResult(ids1, dist1, topk_1, topk, nq, ascending); BuildResult(ids2, dist2, topk_2, topk, nq, ascending); - ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, topk_1, nq, topk, ascending, result); - ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, topk_2, nq, topk, ascending, result); - CheckTopkResult(ids1, dist1, ids2, dist2, topk, nq, ascending, result); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, topk_1, nq, topk, ascending, result_ids, result_distances); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, topk_2, nq, topk, ascending, result_ids, result_distances); + CheckTopkResult(ids1, dist1, topk_1, ids2, dist2, topk_2, topk, nq, ascending, result_ids, result_distances); } TEST(DBSearchTest, MERGE_RESULT_SET_TEST) { - uint64_t NQ = 15; - uint64_t TOP_K = 64; + size_t NQ = 15; + size_t TOP_K = 64; /* test1, id1/dist1 valid, id2/dist2 empty */ MergeTopkToResultSetTest(TOP_K, 0, NQ, TOP_K, true); @@ -159,21 +168,21 @@ TEST(DBSearchTest, MERGE_RESULT_SET_TEST) { MergeTopkToResultSetTest(TOP_K / 2, TOP_K / 3, NQ, TOP_K, false); } -//void MergeTopkArrayTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { +//void MergeTopkArrayTest(size_t topk_1, size_t topk_2, size_t nq, size_t topk, bool ascending) { // std::vector ids1, ids2; // std::vector dist1, dist2; // ms::ResultSet result; // BuildResult(ids1, dist1, topk_1, topk, nq, ascending); // BuildResult(ids2, dist2, topk_2, topk, nq, ascending); -// uint64_t result_topk = std::min(topk, topk_1 + topk_2); +// size_t result_topk = std::min(topk, topk_1 + topk_2); // ms::XSearchTask::MergeTopkArray(ids1, dist1, topk_1, ids2, dist2, topk_2, nq, topk, ascending); // if (ids1.size() != result_topk * nq) { // std::cout << ids1.size() << " " << result_topk * nq << std::endl; // } // ASSERT_TRUE(ids1.size() == result_topk * nq); // ASSERT_TRUE(dist1.size() == result_topk * nq); -// for (uint64_t i = 0; i < nq; i++) { -// for (uint64_t k = 1; k < result_topk; k++) { +// for (size_t i = 0; i < nq; i++) { +// for (size_t k = 1; k < result_topk; k++) { // float f0 = dist1[i * topk + k - 1]; // float f1 = dist1[i * topk + k]; // if (ascending) { @@ -192,8 +201,8 @@ TEST(DBSearchTest, MERGE_RESULT_SET_TEST) { //} //TEST(DBSearchTest, MERGE_ARRAY_TEST) { -// uint64_t NQ = 15; -// uint64_t TOP_K = 64; +// size_t NQ = 15; +// size_t TOP_K = 64; // // /* test1, id1/dist1 valid, id2/dist2 empty */ // MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, true); @@ -222,23 +231,23 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { int32_t index_file_num = 478; /* sift1B dataset, index files num */ bool ascending = true; - std::vector thread_vec = {4, 8}; - std::vector nq_vec = {1, 10, 100}; - std::vector topk_vec = {1, 4, 16, 64}; - int32_t NQ = nq_vec[nq_vec.size() - 1]; - int32_t TOPK = topk_vec[topk_vec.size() - 1]; + std::vector thread_vec = {4}; + std::vector nq_vec = {1000}; + std::vector topk_vec = {64}; + size_t NQ = nq_vec[nq_vec.size() - 1]; + size_t TOPK = topk_vec[topk_vec.size() - 1]; - std::vector> id_vec; - std::vector> dist_vec; - std::vector input_ids; - std::vector input_distance; + std::vector id_vec; + std::vector dist_vec; + ms::ResultIds input_ids; + ms::ResultDistances input_distances; int32_t i, k, step; /* generate testing data */ for (i = 0; i < index_file_num; i++) { - BuildResult(input_ids, input_distance, TOPK, TOPK, NQ, ascending); + BuildResult(input_ids, input_distances, TOPK, TOPK, NQ, ascending); id_vec.push_back(input_ids); - dist_vec.push_back(input_distance); + dist_vec.push_back(input_distances); } for (int32_t max_thread_num : thread_vec) { @@ -247,10 +256,11 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { for (int32_t nq : nq_vec) { for (int32_t top_k : topk_vec) { - ms::ResultSet final_result, final_result_2, final_result_3; + ms::ResultIds final_result_ids, final_result_ids_2, final_result_ids_3; + ms::ResultDistances final_result_distances, final_result_distances_2, final_result_distances_3; - std::vector> id_vec_1(index_file_num); - std::vector> dist_vec_1(index_file_num); + std::vector id_vec_1(index_file_num); + std::vector dist_vec_1(index_file_num); for (i = 0; i < index_file_num; i++) { CopyResult(id_vec_1[i], dist_vec_1[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); } @@ -268,8 +278,10 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { nq, top_k, ascending, - final_result); - ASSERT_EQ(final_result.size(), nq); + final_result_ids, + final_result_distances); + ASSERT_EQ(final_result_ids.size(), nq * top_k); + ASSERT_EQ(final_result_distances.size(), nq * top_k); } rc1.RecordSection("reduce done"); @@ -278,7 +290,7 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { // /* method-2 */ // std::vector> id_vec_2(index_file_num); // std::vector> dist_vec_2(index_file_num); -// std::vector k_vec_2(index_file_num); +// std::vector k_vec_2(index_file_num); // for (i = 0; i < index_file_num; i++) { // CopyResult(id_vec_2[i], dist_vec_2[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); // k_vec_2[i] = top_k; @@ -321,7 +333,7 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { // /* method-3 parallel */ // std::vector> id_vec_3(index_file_num); // std::vector> dist_vec_3(index_file_num); -// std::vector k_vec_3(index_file_num); +// std::vector k_vec_3(index_file_num); // for (i = 0; i < index_file_num; i++) { // CopyResult(id_vec_3[i], dist_vec_3[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); // k_vec_3[i] = top_k; diff --git a/core/unittest/metrics/test_metrics.cpp b/core/unittest/metrics/test_metrics.cpp index c0d1044bb4a6b492256e8e54e5daad4479d31897..1b26ad097b882a3bfbac09a7aa487012c88e9cef 100644 --- a/core/unittest/metrics/test_metrics.cpp +++ b/core/unittest/metrics/test_metrics.cpp @@ -75,7 +75,8 @@ TEST_F(MetricTest, METRIC_TEST) { } std::thread search([&]() { - milvus::engine::QueryResults results; +// milvus::engine::ResultIds result_ids; +// milvus::engine::ResultDistances result_distances; int k = 10; std::this_thread::sleep_for(std::chrono::seconds(2)); @@ -90,7 +91,7 @@ TEST_F(MetricTest, METRIC_TEST) { prev_count = count; START_TIMER; -// stat = db_->Query(group_name, k, qb, qxb, results); +// stat = db_->Query(group_name, k, qb, qxb, result_ids, result_distances); ss << "Search " << j << " With Size " << (float) (count * group_dim * sizeof(float)) / (1024 * 1024) << " M";