diff --git a/cpp/src/db/engine/ExecutionEngine.h b/cpp/src/db/engine/ExecutionEngine.h index a122231a1dce46638a88e5e6fa2ed1e60fc9602f..88be75aeb99988d60033bda17e5fef4a4621ff16 100644 --- a/cpp/src/db/engine/ExecutionEngine.h +++ b/cpp/src/db/engine/ExecutionEngine.h @@ -51,6 +51,7 @@ public: virtual Status Search(long n, const float *data, long k, + long nprobe, float *distances, long *labels) const = 0; diff --git a/cpp/src/db/engine/ExecutionEngineImpl.cpp b/cpp/src/db/engine/ExecutionEngineImpl.cpp index 037c52cf5cc7c7e557ebbb9be3a3a2e86a05782a..dd38369832f3b94578cfc742dd71486bb33568fb 100644 --- a/cpp/src/db/engine/ExecutionEngineImpl.cpp +++ b/cpp/src/db/engine/ExecutionEngineImpl.cpp @@ -228,10 +228,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location) { Status ExecutionEngineImpl::Search(long n, const float *data, long k, + long nprobe, float *distances, long *labels) const { - ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe_; - auto ec = index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe_}}); + ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe; + auto ec = index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe}}); if (ec != server::KNOWHERE_SUCCESS) { ENGINE_LOG_ERROR << "Search error"; return Status::Error("Search: Search Error"); @@ -256,7 +257,6 @@ Status ExecutionEngineImpl::Init() { case EngineType::FAISS_IVFSQ8: case EngineType::FAISS_IVFFLAT: { ConfigNode engine_config = config.GetConfig(CONFIG_ENGINE); - nprobe_ = engine_config.GetInt32Value(CONFIG_NPROBE, 1); nlist_ = engine_config.GetInt32Value(CONFIG_NLIST, 16384); break; } diff --git a/cpp/src/db/engine/ExecutionEngineImpl.h b/cpp/src/db/engine/ExecutionEngineImpl.h index cb50af238b68bf57bdd55989995418c3d844d22e..948719310cd380a984fafc1d51c6a96a354b162a 100644 --- a/cpp/src/db/engine/ExecutionEngineImpl.h +++ b/cpp/src/db/engine/ExecutionEngineImpl.h @@ -51,6 +51,7 @@ public: Status Search(long n, const float *data, long k, + long nprobe, float *distances, long *labels) const override; @@ -73,7 +74,6 @@ protected: int64_t dim; std::string location_; - size_t nprobe_ = 0; size_t nlist_ = 0; int64_t gpu_num = 0; }; diff --git a/cpp/src/db/scheduler/context/SearchContext.h b/cpp/src/db/scheduler/context/SearchContext.h index 48ca7fbd8c37accca64d578895a444227ffd4091..9ca03e08309faf0cbe66c9004083c8c996898cb8 100644 --- a/cpp/src/db/scheduler/context/SearchContext.h +++ b/cpp/src/db/scheduler/context/SearchContext.h @@ -27,6 +27,7 @@ public: uint64_t topk() const { return topk_; } uint64_t nq() const { return nq_; } + uint64_t nprobe() const { return nprobe_; } const float* vectors() const { return vectors_; } using Id2IndexMap = std::unordered_map; @@ -53,7 +54,7 @@ public: private: uint64_t topk_ = 0; uint64_t nq_ = 0; - uint64_t nprobe_ = 0; + uint64_t nprobe_ = 10; const float* vectors_ = nullptr; Id2IndexMap map_index_files_; diff --git a/cpp/src/db/scheduler/task/SearchTask.cpp b/cpp/src/db/scheduler/task/SearchTask.cpp index 79baeeafe9f7d2d23e04a55c0cd2d9b4454ce97c..fd9d679d5e6ce2c761987f46ffb0f5cb1c1cd49c 100644 --- a/cpp/src/db/scheduler/task/SearchTask.cpp +++ b/cpp/src/db/scheduler/task/SearchTask.cpp @@ -109,12 +109,13 @@ std::shared_ptr SearchTask::Execute() { for(auto& context : search_contexts_) { //step 1: allocate memory auto inner_k = context->topk(); + auto nprobe = context->nprobe(); output_ids.resize(inner_k*context->nq()); output_distence.resize(inner_k*context->nq()); try { //step 2: search - index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(), + index_engine_->Search(context->nq(), context->vectors(), inner_k, nprobe, output_distence.data(), output_ids.data()); double span = rc.RecordSection("do search for context:" + context->Identity()); diff --git a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp index 8641e152eff6b804be525f11983ee15678af5efd..583a91789768d0f36ab00da1a4175ff6734651ee 100644 --- a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp +++ b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp @@ -5,6 +5,7 @@ ******************************************************************************/ #include "ClientTest.h" #include "MilvusApi.h" +#include "cache/CpuCacheMgr.h" #include #include @@ -23,7 +24,7 @@ namespace { constexpr int64_t NQ = 10; constexpr int64_t TOP_K = 10; constexpr int64_t SEARCH_TARGET = 5000; //change this value, result is different - constexpr int64_t ADD_VECTOR_LOOP = 5; + constexpr int64_t ADD_VECTOR_LOOP = 1; constexpr int64_t SECONDS_EACH_HOUR = 3600; #define BLOCK_SPLITER std::cout << "===========================================" << std::endl; @@ -174,7 +175,7 @@ namespace { std::vector topk_query_result_array; { TimeRecorder rc(phase_name); - Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 0, topk_query_result_array); + Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 10, topk_query_result_array); std::cout << "SearchVector function call status: " << stat.ToString() << std::endl; } @@ -316,6 +317,11 @@ ClientTest::Test(const std::string& address, const std::string& port) { // std::cout << "BuildIndex function call status: " << stat.ToString() << std::endl; } + {//preload table + Status stat = conn->PreloadTable(TABLE_NAME); + std::cout << "PreloadTable function call status: " << stat.ToString() << std::endl; + } + {//search vectors after build index finish DoSearch(conn, search_record_array, "Search after build index finish"); } diff --git a/cpp/src/sdk/interface/ConnectionImpl.cpp b/cpp/src/sdk/interface/ConnectionImpl.cpp index efee7023ef801bffc017432e61070fa1ad6ad5ea..b496d1c1049c4fb968c8aa59bb90e78ca15e12a1 100644 --- a/cpp/src/sdk/interface/ConnectionImpl.cpp +++ b/cpp/src/sdk/interface/ConnectionImpl.cpp @@ -122,7 +122,7 @@ ConnectionImpl::DeleteByRange(Range &range, Status ConnectionImpl::PreloadTable(const std::string &table_name) const { - + return client_proxy_->PreloadTable(table_name); } IndexParam diff --git a/cpp/unittest/db/db_tests.cpp b/cpp/unittest/db/db_tests.cpp index 07ae4bc964839bb4e6bd0297461eb1077b189a77..8b36d2efbdecf911cdeafa3926df44cbb23ee43e 100644 --- a/cpp/unittest/db/db_tests.cpp +++ b/cpp/unittest/db/db_tests.cpp @@ -8,6 +8,7 @@ #include "db/DBImpl.h" #include "db/meta/MetaConsts.h" #include "db/Factories.h" +#include "cache/CpuCacheMgr.h" #include #include @@ -128,7 +129,7 @@ TEST_F(DBTest, DB_TEST) { prev_count = count; START_TIMER; - stat = db_->Query(TABLE_NAME, k, qb, qxb.data(), results); + stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results); ss << "Search " << j << " With Size " << count/engine::meta::M << " M"; STOP_TIMER(ss.str()); @@ -211,7 +212,7 @@ TEST_F(DBTest, SEARCH_TEST) { { engine::QueryResults results; - stat = db_->Query(TABLE_NAME, k, nq, xq.data(), results); + stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); ASSERT_STATS(stat); } @@ -219,7 +220,7 @@ TEST_F(DBTest, SEARCH_TEST) { engine::meta::DatesT dates; std::vector file_ids = {"4", "5", "6"}; engine::QueryResults results; - stat = db_->Query(TABLE_NAME, file_ids, k, nq, xq.data(), dates, results); + stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results); ASSERT_STATS(stat); } @@ -239,19 +240,19 @@ TEST_F(DBTest, PRELOADTABLE_TEST) { engine::IDNumbers vector_ids; engine::IDNumbers target_ids; - int64_t nb = 50; + int64_t nb = 100000; std::vector xb; BuildVectors(nb, xb); - int loop = INSERT_LOOP; + int loop = 5; for (auto i=0; iInsertVectors(TABLE_NAME, qb, qxb.data(), target_ids); - ASSERT_EQ(target_ids.size(), qb); + db_->InsertVectors(TABLE_NAME, nb, xb.data(), target_ids); + ASSERT_EQ(target_ids.size(), nb); } + db_->BuildIndex(TABLE_NAME); int64_t prev_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage(); - stat = db_->PreloadTable(TABLE_NAME); ASSERT_STATS(stat); int64_t cur_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage(); diff --git a/cpp/unittest/db/mem_test.cpp b/cpp/unittest/db/mem_test.cpp index dc0b9aa77be4d4830578f7e04022475f4230cfb4..1976822e761d20fc538d9cc6c0baecbe8e40fa56 100644 --- a/cpp/unittest/db/mem_test.cpp +++ b/cpp/unittest/db/mem_test.cpp @@ -243,7 +243,7 @@ TEST_F(NewMemManagerTest, SERIAL_INSERT_SEARCH_TEST) { for (auto &pair : search_vectors) { auto &search = pair.second; engine::QueryResults results; - stat = db_->Query(TABLE_NAME, k, 1, search.data(), results); + stat = db_->Query(TABLE_NAME, k, 1, 10, search.data(), results); ASSERT_EQ(results[0][0].first, pair.first); ASSERT_LT(results[0][0].second, 0.00001); } @@ -332,7 +332,7 @@ TEST_F(NewMemManagerTest, CONCURRENT_INSERT_SEARCH_TEST) { prev_count = count; START_TIMER; - stat = db_->Query(TABLE_NAME, k, qb, qxb.data(), results); + stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results); ss << "Search " << j << " With Size " << count / engine::meta::M << " M"; STOP_TIMER(ss.str()); diff --git a/cpp/unittest/db/mysql_db_test.cpp b/cpp/unittest/db/mysql_db_test.cpp index 462721f186960c961e450465ea737d2e344954a5..78adf9f0f51a471b82a311e1e2a6ed5d9abd9819 100644 --- a/cpp/unittest/db/mysql_db_test.cpp +++ b/cpp/unittest/db/mysql_db_test.cpp @@ -90,7 +90,7 @@ TEST_F(DISABLED_MySQLDBTest, DB_TEST) { prev_count = count; START_TIMER; - stat = db_->Query(TABLE_NAME, k, qb, qxb.data(), results); + stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results); ss << "Search " << j << " With Size " << count/engine::meta::M << " M"; STOP_TIMER(ss.str()); @@ -190,7 +190,7 @@ TEST_F(DISABLED_MySQLDBTest, SEARCH_TEST) { sleep(2); // wait until build index finish engine::QueryResults results; - stat = db_->Query(TABLE_NAME, k, nq, xq.data(), results); + stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); ASSERT_STATS(stat); delete db_; diff --git a/cpp/unittest/db/scheduler_test.cpp b/cpp/unittest/db/scheduler_test.cpp index 01a7057e001d152cf6c5b6dfc836ce55c20c0224..0937ef197acda72d5f2a5bac24e986c564ee43f8 100644 --- a/cpp/unittest/db/scheduler_test.cpp +++ b/cpp/unittest/db/scheduler_test.cpp @@ -38,7 +38,7 @@ TEST(DBSchedulerTest, TASK_QUEUE_TEST) { ASSERT_EQ(ptr, nullptr); ASSERT_TRUE(queue.Empty()); - engine::SearchContextPtr context_ptr = std::make_shared(1, 1, nullptr); + engine::SearchContextPtr context_ptr = std::make_shared(1, 1, 10, nullptr); for(size_t i = 0; i < 10; i++) { auto file = CreateTabileFileStruct(i, "tbl"); context_ptr->AddIndexFile(file); @@ -69,7 +69,7 @@ TEST(DBSchedulerTest, SEARCH_SCHEDULER_TEST) { task_list.push_back(task_ptr); } - engine::SearchContextPtr context_ptr = std::make_shared(1, 1, nullptr); + engine::SearchContextPtr context_ptr = std::make_shared(1, 1, 10, nullptr); for(size_t i = 0; i < 20; i++) { auto file = CreateTabileFileStruct(i, "tbl"); context_ptr->AddIndexFile(file);