提交 b7bc62d4 编写于 作者: G groot

refine code


Former-commit-id: 17c00857221bc167525f7c340c99061697c0c547
上级 ec42c86e
...@@ -652,7 +652,7 @@ Status DBImpl::BuildIndex(const meta::TableFileSchema& file) { ...@@ -652,7 +652,7 @@ Status DBImpl::BuildIndex(const meta::TableFileSchema& file) {
<< index->PhysicalSize()/(1024*1024) << " M" << index->PhysicalSize()/(1024*1024) << " M"
<< " from file " << to_remove.file_id_; << " from file " << to_remove.file_id_;
//index->Cache(); index->Cache();
} catch (std::exception& ex) { } catch (std::exception& ex) {
return Status::Error("Build index encounter exception", ex.what()); return Status::Error("Build index encounter exception", ex.what());
......
...@@ -55,7 +55,7 @@ void MergeResult(SearchContext::Id2ScoreMap &score_src, ...@@ -55,7 +55,7 @@ void MergeResult(SearchContext::Id2ScoreMap &score_src,
while(true) { while(true) {
//all score_src items are merged, if score_merged.size() still less than topk //all score_src items are merged, if score_merged.size() still less than topk
//move items from score_target to score_merged until score_merged.size() equal topk //move items from score_target to score_merged until score_merged.size() equal topk
if(src_index >= src_count - 1) { if(src_index >= src_count) {
for(size_t i = target_index; i < target_count && score_merged.size() < topk; ++i) { for(size_t i = target_index; i < target_count && score_merged.size() < topk; ++i) {
score_merged.push_back(score_target[i]); score_merged.push_back(score_target[i]);
} }
...@@ -64,7 +64,7 @@ void MergeResult(SearchContext::Id2ScoreMap &score_src, ...@@ -64,7 +64,7 @@ void MergeResult(SearchContext::Id2ScoreMap &score_src,
//all score_target items are merged, if score_merged.size() still less than topk //all score_target items are merged, if score_merged.size() still less than topk
//move items from score_src to score_merged until score_merged.size() equal topk //move items from score_src to score_merged until score_merged.size() equal topk
if(target_index >= target_count - 1) { if(target_index >= target_count) {
for(size_t i = src_index; i < src_count && score_merged.size() < topk; ++i) { for(size_t i = src_index; i < src_count && score_merged.size() < topk; ++i) {
score_merged.push_back(score_src[i]); score_merged.push_back(score_src[i]);
} }
......
...@@ -17,10 +17,11 @@ namespace { ...@@ -17,10 +17,11 @@ namespace {
static const std::string TABLE_NAME = GetTableName(); static const std::string TABLE_NAME = GetTableName();
static constexpr int64_t TABLE_DIMENSION = 512; static constexpr int64_t TABLE_DIMENSION = 512;
static constexpr int64_t TOTAL_ROW_COUNT = 100000; static constexpr int64_t BATCH_ROW_COUNT = 100000;
static constexpr int64_t NQ = 10;
static constexpr int64_t TOP_K = 10; static constexpr int64_t TOP_K = 10;
static constexpr int64_t SEARCH_TARGET = 5000; //change this value, result is different static constexpr int64_t SEARCH_TARGET = 5000; //change this value, result is different
static constexpr int64_t ADD_VECTOR_LOOP = 10; static constexpr int64_t ADD_VECTOR_LOOP = 5;
#define BLOCK_SPLITER std::cout << "===========================================" << std::endl; #define BLOCK_SPLITER std::cout << "===========================================" << std::endl;
...@@ -96,7 +97,7 @@ namespace { ...@@ -96,7 +97,7 @@ namespace {
TableSchema BuildTableSchema() { TableSchema BuildTableSchema() {
TableSchema tb_schema; TableSchema tb_schema;
tb_schema.table_name = TABLE_NAME; tb_schema.table_name = TABLE_NAME;
tb_schema.index_type = IndexType::gpu_ivfflat; tb_schema.index_type = IndexType::cpu_idmap;
tb_schema.dimension = TABLE_DIMENSION; tb_schema.dimension = TABLE_DIMENSION;
tb_schema.store_raw_vector = true; tb_schema.store_raw_vector = true;
...@@ -110,17 +111,21 @@ namespace { ...@@ -110,17 +111,21 @@ namespace {
} }
vector_record_array.clear(); vector_record_array.clear();
for (int64_t k = from; k < to; k++) { for (int64_t k = from; k < to; k++) {
RowRecord record; RowRecord record;
record.data.resize(TABLE_DIMENSION); record.data.resize(TABLE_DIMENSION);
for(int64_t i = 0; i < TABLE_DIMENSION; i++) { for(int64_t i = 0; i < TABLE_DIMENSION; i++) {
record.data[i] = (float)(i + k); record.data[i] = (float)(k%(i+1));
} }
vector_record_array.emplace_back(record); vector_record_array.emplace_back(record);
} }
} }
void Sleep(int seconds) {
std::cout << "Waiting " << seconds << " seconds ..." << std::endl;
sleep(seconds);
}
} }
void void
...@@ -171,7 +176,7 @@ ClientTest::Test(const std::string& address, const std::string& port) { ...@@ -171,7 +176,7 @@ ClientTest::Test(const std::string& address, const std::string& port) {
for(int i = 0; i < ADD_VECTOR_LOOP; i++){//add vectors for(int i = 0; i < ADD_VECTOR_LOOP; i++){//add vectors
std::vector<RowRecord> record_array; std::vector<RowRecord> record_array;
BuildVectors(i*TOTAL_ROW_COUNT, (i+1)*TOTAL_ROW_COUNT, record_array); BuildVectors(i*BATCH_ROW_COUNT, (i+1)*BATCH_ROW_COUNT, record_array);
std::vector<int64_t> record_ids; std::vector<int64_t> record_ids;
Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids); Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
std::cout << "AddVector function call status: " << stat.ToString() << std::endl; std::cout << "AddVector function call status: " << stat.ToString() << std::endl;
...@@ -179,10 +184,10 @@ ClientTest::Test(const std::string& address, const std::string& port) { ...@@ -179,10 +184,10 @@ ClientTest::Test(const std::string& address, const std::string& port) {
} }
{//search vectors {//search vectors
std::cout << "Waiting data persist. Sleep 1 seconds ..." << std::endl; Sleep(2);
sleep(1);
std::vector<RowRecord> record_array; std::vector<RowRecord> record_array;
BuildVectors(SEARCH_TARGET, SEARCH_TARGET + 10, record_array); BuildVectors(SEARCH_TARGET, SEARCH_TARGET + NQ, record_array);
std::vector<Range> query_range_array; std::vector<Range> query_range_array;
Range rg; Range rg;
......
...@@ -69,7 +69,7 @@ TEST_F(DBTest2, ARHIVE_DISK_CHECK) { ...@@ -69,7 +69,7 @@ TEST_F(DBTest2, ARHIVE_DISK_CHECK) {
engine::meta::TableSchema group_info; engine::meta::TableSchema group_info;
group_info.dimension_ = group_dim; group_info.dimension_ = group_dim;
group_info.table_id_ = group_name; group_info.table_id_ = group_name;
group_info.engine_type_ = (int)engine::EngineType::FAISS_IVFFLAT; group_info.engine_type_ = (int)engine::EngineType::FAISS_IDMAP;
engine::Status stat = db_->CreateTable(group_info); engine::Status stat = db_->CreateTable(group_info);
engine::meta::TableSchema group_info_get; engine::meta::TableSchema group_info_get;
...@@ -101,30 +101,27 @@ TEST_F(DBTest2, ARHIVE_DISK_CHECK) { ...@@ -101,30 +101,27 @@ TEST_F(DBTest2, ARHIVE_DISK_CHECK) {
db_->Size(size); db_->Size(size);
LOG(DEBUG) << "size=" << size; LOG(DEBUG) << "size=" << size;
ASSERT_TRUE(size < 1 * engine::meta::G); ASSERT_LT(size, 1 * engine::meta::G);
delete [] xb; delete [] xb;
}; };
TEST_F(DBTest, DB_TEST) { TEST_F(DBTest, DB_TEST) {
static const std::string table_name = "test_group";
static const int table_dim = 256;
static const std::string group_name = "test_group"; engine::meta::TableSchema table_info;
static const int group_dim = 256; table_info.dimension_ = table_dim;
table_info.table_id_ = table_name;
engine::meta::TableSchema group_info; table_info.engine_type_ = (int)engine::EngineType::FAISS_IDMAP;
group_info.dimension_ = group_dim; engine::Status stat = db_->CreateTable(table_info);
group_info.table_id_ = group_name;
group_info.engine_type_ = (int)engine::EngineType::FAISS_IVFFLAT; engine::meta::TableSchema table_info_get;
engine::Status stat = db_->CreateTable(group_info); table_info_get.table_id_ = table_name;
stat = db_->DescribeTable(table_info_get);
engine::meta::TableSchema group_info_get;
group_info_get.table_id_ = group_name;
stat = db_->DescribeTable(group_info_get);
ASSERT_STATS(stat); ASSERT_STATS(stat);
ASSERT_EQ(group_info_get.dimension_, group_dim); ASSERT_EQ(table_info_get.dimension_, table_dim);
engine::IDNumbers vector_ids; engine::IDNumbers vector_ids;
engine::IDNumbers target_ids; engine::IDNumbers target_ids;
...@@ -160,7 +157,7 @@ TEST_F(DBTest, DB_TEST) { ...@@ -160,7 +157,7 @@ TEST_F(DBTest, DB_TEST) {
prev_count = count; prev_count = count;
START_TIMER; START_TIMER;
stat = db_->Query(group_name, k, qb, qxb, results); stat = db_->Query(table_name, k, qb, qxb, results);
ss << "Search " << j << " With Size " << count/engine::meta::M << " M"; ss << "Search " << j << " With Size " << count/engine::meta::M << " M";
STOP_TIMER(ss.str()); STOP_TIMER(ss.str());
...@@ -183,10 +180,10 @@ TEST_F(DBTest, DB_TEST) { ...@@ -183,10 +180,10 @@ TEST_F(DBTest, DB_TEST) {
for (auto i=0; i<loop; ++i) { for (auto i=0; i<loop; ++i) {
if (i==40) { if (i==40) {
db_->InsertVectors(group_name, qb, qxb, target_ids); db_->InsertVectors(table_name, qb, qxb, target_ids);
ASSERT_EQ(target_ids.size(), qb); ASSERT_EQ(target_ids.size(), qb);
} else { } else {
db_->InsertVectors(group_name, nb, xb, vector_ids); db_->InsertVectors(table_name, nb, xb, vector_ids);
} }
std::this_thread::sleep_for(std::chrono::microseconds(1)); std::this_thread::sleep_for(std::chrono::microseconds(1));
} }
...@@ -198,20 +195,20 @@ TEST_F(DBTest, DB_TEST) { ...@@ -198,20 +195,20 @@ TEST_F(DBTest, DB_TEST) {
}; };
TEST_F(DBTest, SEARCH_TEST) { TEST_F(DBTest, SEARCH_TEST) {
static const std::string group_name = "test_group"; static const std::string table_name = "test_group";
static const int group_dim = 256; static const int group_dim = 256;
engine::meta::TableSchema group_info; engine::meta::TableSchema table_info;
group_info.dimension_ = group_dim; table_info.dimension_ = group_dim;
group_info.table_id_ = group_name; table_info.table_id_ = table_name;
group_info.engine_type_ = (int)engine::EngineType::FAISS_IVFFLAT; table_info.engine_type_ = (int)engine::EngineType::FAISS_IDMAP;
engine::Status stat = db_->CreateTable(group_info); engine::Status stat = db_->CreateTable(table_info);
engine::meta::TableSchema group_info_get; engine::meta::TableSchema table_info_get;
group_info_get.table_id_ = group_name; table_info_get.table_id_ = table_name;
stat = db_->DescribeTable(group_info_get); stat = db_->DescribeTable(table_info_get);
ASSERT_STATS(stat); ASSERT_STATS(stat);
ASSERT_EQ(group_info_get.dimension_, group_dim); ASSERT_EQ(table_info_get.dimension_, group_dim);
// prepare raw data // prepare raw data
size_t nb = 250000; size_t nb = 250000;
...@@ -243,7 +240,7 @@ TEST_F(DBTest, SEARCH_TEST) { ...@@ -243,7 +240,7 @@ TEST_F(DBTest, SEARCH_TEST) {
// insert data // insert data
const int batch_size = 100; const int batch_size = 100;
for (int j = 0; j < nb / batch_size; ++j) { for (int j = 0; j < nb / batch_size; ++j) {
stat = db_->InsertVectors(group_name, batch_size, xb.data()+batch_size*j*group_dim, ids); stat = db_->InsertVectors(table_name, batch_size, xb.data()+batch_size*j*group_dim, ids);
if (j == 200){ sleep(1);} if (j == 200){ sleep(1);}
ASSERT_STATS(stat); ASSERT_STATS(stat);
} }
...@@ -251,7 +248,7 @@ TEST_F(DBTest, SEARCH_TEST) { ...@@ -251,7 +248,7 @@ TEST_F(DBTest, SEARCH_TEST) {
sleep(2); // wait until build index finish sleep(2); // wait until build index finish
engine::QueryResults results; engine::QueryResults results;
stat = db_->Query(group_name, k, nq, xq.data(), results); stat = db_->Query(table_name, k, nq, xq.data(), results);
ASSERT_STATS(stat); ASSERT_STATS(stat);
// TODO(linxj): add groundTruth assert // TODO(linxj): add groundTruth assert
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册