From b7bc62d454c308d9972e5c184788a58a159fae2c Mon Sep 17 00:00:00 2001 From: groot Date: Fri, 21 Jun 2019 10:32:46 +0800 Subject: [PATCH] refine code Former-commit-id: 17c00857221bc167525f7c340c99061697c0c547 --- cpp/src/db/DBImpl.cpp | 2 +- cpp/src/db/scheduler/task/SearchTask.cpp | 4 +- .../sdk/examples/simple/src/ClientTest.cpp | 23 ++++--- cpp/unittest/db/db_tests.cpp | 63 +++++++++---------- 4 files changed, 47 insertions(+), 45 deletions(-) diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index 8202c8a7..7642ea37 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -652,7 +652,7 @@ Status DBImpl::BuildIndex(const meta::TableFileSchema& file) { << index->PhysicalSize()/(1024*1024) << " M" << " from file " << to_remove.file_id_; - //index->Cache(); + index->Cache(); } catch (std::exception& ex) { return Status::Error("Build index encounter exception", ex.what()); diff --git a/cpp/src/db/scheduler/task/SearchTask.cpp b/cpp/src/db/scheduler/task/SearchTask.cpp index d8c37269..d04f2703 100644 --- a/cpp/src/db/scheduler/task/SearchTask.cpp +++ b/cpp/src/db/scheduler/task/SearchTask.cpp @@ -55,7 +55,7 @@ void MergeResult(SearchContext::Id2ScoreMap &score_src, while(true) { //all score_src items are merged, if score_merged.size() still less than topk //move items from score_target to score_merged until score_merged.size() equal topk - if(src_index >= src_count - 1) { + if(src_index >= src_count) { for(size_t i = target_index; i < target_count && score_merged.size() < topk; ++i) { score_merged.push_back(score_target[i]); } @@ -64,7 +64,7 @@ void MergeResult(SearchContext::Id2ScoreMap &score_src, //all score_target items are merged, if score_merged.size() still less than topk //move items from score_src to score_merged until score_merged.size() equal topk - if(target_index >= target_count - 1) { + if(target_index >= target_count) { for(size_t i = src_index; i < src_count && score_merged.size() < topk; ++i) { score_merged.push_back(score_src[i]); } diff --git a/cpp/src/sdk/examples/simple/src/ClientTest.cpp b/cpp/src/sdk/examples/simple/src/ClientTest.cpp index 3aad4e07..78145446 100644 --- a/cpp/src/sdk/examples/simple/src/ClientTest.cpp +++ b/cpp/src/sdk/examples/simple/src/ClientTest.cpp @@ -17,10 +17,11 @@ namespace { static const std::string TABLE_NAME = GetTableName(); static constexpr int64_t TABLE_DIMENSION = 512; - static constexpr int64_t TOTAL_ROW_COUNT = 100000; + static constexpr int64_t BATCH_ROW_COUNT = 100000; + static constexpr int64_t NQ = 10; static constexpr int64_t TOP_K = 10; static constexpr int64_t SEARCH_TARGET = 5000; //change this value, result is different - static constexpr int64_t ADD_VECTOR_LOOP = 10; + static constexpr int64_t ADD_VECTOR_LOOP = 5; #define BLOCK_SPLITER std::cout << "===========================================" << std::endl; @@ -96,7 +97,7 @@ namespace { TableSchema BuildTableSchema() { TableSchema tb_schema; tb_schema.table_name = TABLE_NAME; - tb_schema.index_type = IndexType::gpu_ivfflat; + tb_schema.index_type = IndexType::cpu_idmap; tb_schema.dimension = TABLE_DIMENSION; tb_schema.store_raw_vector = true; @@ -110,17 +111,21 @@ namespace { } vector_record_array.clear(); - for (int64_t k = from; k < to; k++) { RowRecord record; record.data.resize(TABLE_DIMENSION); for(int64_t i = 0; i < TABLE_DIMENSION; i++) { - record.data[i] = (float)(i + k); + record.data[i] = (float)(k%(i+1)); } vector_record_array.emplace_back(record); } } + + void Sleep(int seconds) { + std::cout << "Waiting " << seconds << " seconds ..." << std::endl; + sleep(seconds); + } } void @@ -171,7 +176,7 @@ ClientTest::Test(const std::string& address, const std::string& port) { for(int i = 0; i < ADD_VECTOR_LOOP; i++){//add vectors std::vector record_array; - BuildVectors(i*TOTAL_ROW_COUNT, (i+1)*TOTAL_ROW_COUNT, record_array); + BuildVectors(i*BATCH_ROW_COUNT, (i+1)*BATCH_ROW_COUNT, record_array); std::vector record_ids; Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids); std::cout << "AddVector function call status: " << stat.ToString() << std::endl; @@ -179,10 +184,10 @@ ClientTest::Test(const std::string& address, const std::string& port) { } {//search vectors - std::cout << "Waiting data persist. Sleep 1 seconds ..." << std::endl; - sleep(1); + Sleep(2); + std::vector record_array; - BuildVectors(SEARCH_TARGET, SEARCH_TARGET + 10, record_array); + BuildVectors(SEARCH_TARGET, SEARCH_TARGET + NQ, record_array); std::vector query_range_array; Range rg; diff --git a/cpp/unittest/db/db_tests.cpp b/cpp/unittest/db/db_tests.cpp index 00b6c7a9..6cfbe913 100644 --- a/cpp/unittest/db/db_tests.cpp +++ b/cpp/unittest/db/db_tests.cpp @@ -69,7 +69,7 @@ TEST_F(DBTest2, ARHIVE_DISK_CHECK) { engine::meta::TableSchema group_info; group_info.dimension_ = group_dim; group_info.table_id_ = group_name; - group_info.engine_type_ = (int)engine::EngineType::FAISS_IVFFLAT; + group_info.engine_type_ = (int)engine::EngineType::FAISS_IDMAP; engine::Status stat = db_->CreateTable(group_info); engine::meta::TableSchema group_info_get; @@ -101,30 +101,27 @@ TEST_F(DBTest2, ARHIVE_DISK_CHECK) { db_->Size(size); LOG(DEBUG) << "size=" << size; - ASSERT_TRUE(size < 1 * engine::meta::G); + ASSERT_LT(size, 1 * engine::meta::G); delete [] xb; }; TEST_F(DBTest, DB_TEST) { - - - - static const std::string group_name = "test_group"; - static const int group_dim = 256; - - engine::meta::TableSchema group_info; - group_info.dimension_ = group_dim; - group_info.table_id_ = group_name; - group_info.engine_type_ = (int)engine::EngineType::FAISS_IVFFLAT; - engine::Status stat = db_->CreateTable(group_info); - - engine::meta::TableSchema group_info_get; - group_info_get.table_id_ = group_name; - stat = db_->DescribeTable(group_info_get); + static const std::string table_name = "test_group"; + static const int table_dim = 256; + + engine::meta::TableSchema table_info; + table_info.dimension_ = table_dim; + table_info.table_id_ = table_name; + table_info.engine_type_ = (int)engine::EngineType::FAISS_IDMAP; + engine::Status stat = db_->CreateTable(table_info); + + engine::meta::TableSchema table_info_get; + table_info_get.table_id_ = table_name; + stat = db_->DescribeTable(table_info_get); ASSERT_STATS(stat); - ASSERT_EQ(group_info_get.dimension_, group_dim); + ASSERT_EQ(table_info_get.dimension_, table_dim); engine::IDNumbers vector_ids; engine::IDNumbers target_ids; @@ -160,7 +157,7 @@ TEST_F(DBTest, DB_TEST) { prev_count = count; START_TIMER; - stat = db_->Query(group_name, k, qb, qxb, results); + stat = db_->Query(table_name, k, qb, qxb, results); ss << "Search " << j << " With Size " << count/engine::meta::M << " M"; STOP_TIMER(ss.str()); @@ -183,10 +180,10 @@ TEST_F(DBTest, DB_TEST) { for (auto i=0; iInsertVectors(group_name, qb, qxb, target_ids); + db_->InsertVectors(table_name, qb, qxb, target_ids); ASSERT_EQ(target_ids.size(), qb); } else { - db_->InsertVectors(group_name, nb, xb, vector_ids); + db_->InsertVectors(table_name, nb, xb, vector_ids); } std::this_thread::sleep_for(std::chrono::microseconds(1)); } @@ -198,20 +195,20 @@ TEST_F(DBTest, DB_TEST) { }; TEST_F(DBTest, SEARCH_TEST) { - static const std::string group_name = "test_group"; + static const std::string table_name = "test_group"; static const int group_dim = 256; - engine::meta::TableSchema group_info; - group_info.dimension_ = group_dim; - group_info.table_id_ = group_name; - group_info.engine_type_ = (int)engine::EngineType::FAISS_IVFFLAT; - engine::Status stat = db_->CreateTable(group_info); + engine::meta::TableSchema table_info; + table_info.dimension_ = group_dim; + table_info.table_id_ = table_name; + table_info.engine_type_ = (int)engine::EngineType::FAISS_IDMAP; + engine::Status stat = db_->CreateTable(table_info); - engine::meta::TableSchema group_info_get; - group_info_get.table_id_ = group_name; - stat = db_->DescribeTable(group_info_get); + engine::meta::TableSchema table_info_get; + table_info_get.table_id_ = table_name; + stat = db_->DescribeTable(table_info_get); ASSERT_STATS(stat); - ASSERT_EQ(group_info_get.dimension_, group_dim); + ASSERT_EQ(table_info_get.dimension_, group_dim); // prepare raw data size_t nb = 250000; @@ -243,7 +240,7 @@ TEST_F(DBTest, SEARCH_TEST) { // insert data const int batch_size = 100; for (int j = 0; j < nb / batch_size; ++j) { - stat = db_->InsertVectors(group_name, batch_size, xb.data()+batch_size*j*group_dim, ids); + stat = db_->InsertVectors(table_name, batch_size, xb.data()+batch_size*j*group_dim, ids); if (j == 200){ sleep(1);} ASSERT_STATS(stat); } @@ -251,7 +248,7 @@ TEST_F(DBTest, SEARCH_TEST) { sleep(2); // wait until build index finish engine::QueryResults results; - stat = db_->Query(group_name, k, nq, xq.data(), results); + stat = db_->Query(table_name, k, nq, xq.data(), results); ASSERT_STATS(stat); // TODO(linxj): add groundTruth assert -- GitLab