提交 fabcbe91 编写于 作者: Y Yu Kun

fix preloadtable unittest bug


Former-commit-id: 8d19429bb3d4d58a1fcc31700dcc74685188280e
上级 536daeea
...@@ -51,6 +51,7 @@ public: ...@@ -51,6 +51,7 @@ public:
virtual Status Search(long n, virtual Status Search(long n,
const float *data, const float *data,
long k, long k,
long nprobe,
float *distances, float *distances,
long *labels) const = 0; long *labels) const = 0;
......
...@@ -228,10 +228,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location) { ...@@ -228,10 +228,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location) {
Status ExecutionEngineImpl::Search(long n, Status ExecutionEngineImpl::Search(long n,
const float *data, const float *data,
long k, long k,
long nprobe,
float *distances, float *distances,
long *labels) const { long *labels) const {
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe_; ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
auto ec = index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe_}}); auto ec = index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe}});
if (ec != server::KNOWHERE_SUCCESS) { if (ec != server::KNOWHERE_SUCCESS) {
ENGINE_LOG_ERROR << "Search error"; ENGINE_LOG_ERROR << "Search error";
return Status::Error("Search: Search Error"); return Status::Error("Search: Search Error");
...@@ -256,7 +257,6 @@ Status ExecutionEngineImpl::Init() { ...@@ -256,7 +257,6 @@ Status ExecutionEngineImpl::Init() {
case EngineType::FAISS_IVFSQ8: case EngineType::FAISS_IVFSQ8:
case EngineType::FAISS_IVFFLAT: { case EngineType::FAISS_IVFFLAT: {
ConfigNode engine_config = config.GetConfig(CONFIG_ENGINE); ConfigNode engine_config = config.GetConfig(CONFIG_ENGINE);
nprobe_ = engine_config.GetInt32Value(CONFIG_NPROBE, 1);
nlist_ = engine_config.GetInt32Value(CONFIG_NLIST, 16384); nlist_ = engine_config.GetInt32Value(CONFIG_NLIST, 16384);
break; break;
} }
......
...@@ -51,6 +51,7 @@ public: ...@@ -51,6 +51,7 @@ public:
Status Search(long n, Status Search(long n,
const float *data, const float *data,
long k, long k,
long nprobe,
float *distances, float *distances,
long *labels) const override; long *labels) const override;
...@@ -73,7 +74,6 @@ protected: ...@@ -73,7 +74,6 @@ protected:
int64_t dim; int64_t dim;
std::string location_; std::string location_;
size_t nprobe_ = 0;
size_t nlist_ = 0; size_t nlist_ = 0;
int64_t gpu_num = 0; int64_t gpu_num = 0;
}; };
......
...@@ -27,6 +27,7 @@ public: ...@@ -27,6 +27,7 @@ public:
uint64_t topk() const { return topk_; } uint64_t topk() const { return topk_; }
uint64_t nq() const { return nq_; } uint64_t nq() const { return nq_; }
uint64_t nprobe() const { return nprobe_; }
const float* vectors() const { return vectors_; } const float* vectors() const { return vectors_; }
using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>; using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
...@@ -53,7 +54,7 @@ public: ...@@ -53,7 +54,7 @@ public:
private: private:
uint64_t topk_ = 0; uint64_t topk_ = 0;
uint64_t nq_ = 0; uint64_t nq_ = 0;
uint64_t nprobe_ = 0; uint64_t nprobe_ = 10;
const float* vectors_ = nullptr; const float* vectors_ = nullptr;
Id2IndexMap map_index_files_; Id2IndexMap map_index_files_;
......
...@@ -109,12 +109,13 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() { ...@@ -109,12 +109,13 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
for(auto& context : search_contexts_) { for(auto& context : search_contexts_) {
//step 1: allocate memory //step 1: allocate memory
auto inner_k = context->topk(); auto inner_k = context->topk();
auto nprobe = context->nprobe();
output_ids.resize(inner_k*context->nq()); output_ids.resize(inner_k*context->nq());
output_distence.resize(inner_k*context->nq()); output_distence.resize(inner_k*context->nq());
try { try {
//step 2: search //step 2: search
index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(), index_engine_->Search(context->nq(), context->vectors(), inner_k, nprobe, output_distence.data(),
output_ids.data()); output_ids.data());
double span = rc.RecordSection("do search for context:" + context->Identity()); double span = rc.RecordSection("do search for context:" + context->Identity());
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
******************************************************************************/ ******************************************************************************/
#include "ClientTest.h" #include "ClientTest.h"
#include "MilvusApi.h" #include "MilvusApi.h"
#include "cache/CpuCacheMgr.h"
#include <iostream> #include <iostream>
#include <time.h> #include <time.h>
...@@ -23,7 +24,7 @@ namespace { ...@@ -23,7 +24,7 @@ namespace {
constexpr int64_t NQ = 10; constexpr int64_t NQ = 10;
constexpr int64_t TOP_K = 10; constexpr int64_t TOP_K = 10;
constexpr int64_t SEARCH_TARGET = 5000; //change this value, result is different constexpr int64_t SEARCH_TARGET = 5000; //change this value, result is different
constexpr int64_t ADD_VECTOR_LOOP = 5; constexpr int64_t ADD_VECTOR_LOOP = 1;
constexpr int64_t SECONDS_EACH_HOUR = 3600; constexpr int64_t SECONDS_EACH_HOUR = 3600;
#define BLOCK_SPLITER std::cout << "===========================================" << std::endl; #define BLOCK_SPLITER std::cout << "===========================================" << std::endl;
...@@ -174,7 +175,7 @@ namespace { ...@@ -174,7 +175,7 @@ namespace {
std::vector<TopKQueryResult> topk_query_result_array; std::vector<TopKQueryResult> topk_query_result_array;
{ {
TimeRecorder rc(phase_name); TimeRecorder rc(phase_name);
Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 0, topk_query_result_array); Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 10, topk_query_result_array);
std::cout << "SearchVector function call status: " << stat.ToString() << std::endl; std::cout << "SearchVector function call status: " << stat.ToString() << std::endl;
} }
...@@ -316,6 +317,11 @@ ClientTest::Test(const std::string& address, const std::string& port) { ...@@ -316,6 +317,11 @@ ClientTest::Test(const std::string& address, const std::string& port) {
// std::cout << "BuildIndex function call status: " << stat.ToString() << std::endl; // std::cout << "BuildIndex function call status: " << stat.ToString() << std::endl;
} }
{//preload table
Status stat = conn->PreloadTable(TABLE_NAME);
std::cout << "PreloadTable function call status: " << stat.ToString() << std::endl;
}
{//search vectors after build index finish {//search vectors after build index finish
DoSearch(conn, search_record_array, "Search after build index finish"); DoSearch(conn, search_record_array, "Search after build index finish");
} }
......
...@@ -122,7 +122,7 @@ ConnectionImpl::DeleteByRange(Range &range, ...@@ -122,7 +122,7 @@ ConnectionImpl::DeleteByRange(Range &range,
Status Status
ConnectionImpl::PreloadTable(const std::string &table_name) const { ConnectionImpl::PreloadTable(const std::string &table_name) const {
return client_proxy_->PreloadTable(table_name);
} }
IndexParam IndexParam
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "db/DBImpl.h" #include "db/DBImpl.h"
#include "db/meta/MetaConsts.h" #include "db/meta/MetaConsts.h"
#include "db/Factories.h" #include "db/Factories.h"
#include "cache/CpuCacheMgr.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <easylogging++.h> #include <easylogging++.h>
...@@ -128,7 +129,7 @@ TEST_F(DBTest, DB_TEST) { ...@@ -128,7 +129,7 @@ TEST_F(DBTest, DB_TEST) {
prev_count = count; prev_count = count;
START_TIMER; START_TIMER;
stat = db_->Query(TABLE_NAME, k, qb, qxb.data(), results); stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results);
ss << "Search " << j << " With Size " << count/engine::meta::M << " M"; ss << "Search " << j << " With Size " << count/engine::meta::M << " M";
STOP_TIMER(ss.str()); STOP_TIMER(ss.str());
...@@ -211,7 +212,7 @@ TEST_F(DBTest, SEARCH_TEST) { ...@@ -211,7 +212,7 @@ TEST_F(DBTest, SEARCH_TEST) {
{ {
engine::QueryResults results; engine::QueryResults results;
stat = db_->Query(TABLE_NAME, k, nq, xq.data(), results); stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results);
ASSERT_STATS(stat); ASSERT_STATS(stat);
} }
...@@ -219,7 +220,7 @@ TEST_F(DBTest, SEARCH_TEST) { ...@@ -219,7 +220,7 @@ TEST_F(DBTest, SEARCH_TEST) {
engine::meta::DatesT dates; engine::meta::DatesT dates;
std::vector<std::string> file_ids = {"4", "5", "6"}; std::vector<std::string> file_ids = {"4", "5", "6"};
engine::QueryResults results; engine::QueryResults results;
stat = db_->Query(TABLE_NAME, file_ids, k, nq, xq.data(), dates, results); stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results);
ASSERT_STATS(stat); ASSERT_STATS(stat);
} }
...@@ -239,19 +240,19 @@ TEST_F(DBTest, PRELOADTABLE_TEST) { ...@@ -239,19 +240,19 @@ TEST_F(DBTest, PRELOADTABLE_TEST) {
engine::IDNumbers vector_ids; engine::IDNumbers vector_ids;
engine::IDNumbers target_ids; engine::IDNumbers target_ids;
int64_t nb = 50; int64_t nb = 100000;
std::vector<float> xb; std::vector<float> xb;
BuildVectors(nb, xb); BuildVectors(nb, xb);
int loop = INSERT_LOOP; int loop = 5;
for (auto i=0; i<loop; ++i) { for (auto i=0; i<loop; ++i) {
db_->InsertVectors(TABLE_NAME, qb, qxb.data(), target_ids); db_->InsertVectors(TABLE_NAME, nb, xb.data(), target_ids);
ASSERT_EQ(target_ids.size(), qb); ASSERT_EQ(target_ids.size(), nb);
} }
db_->BuildIndex(TABLE_NAME);
int64_t prev_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage(); int64_t prev_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage();
stat = db_->PreloadTable(TABLE_NAME); stat = db_->PreloadTable(TABLE_NAME);
ASSERT_STATS(stat); ASSERT_STATS(stat);
int64_t cur_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage(); int64_t cur_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage();
......
...@@ -243,7 +243,7 @@ TEST_F(NewMemManagerTest, SERIAL_INSERT_SEARCH_TEST) { ...@@ -243,7 +243,7 @@ TEST_F(NewMemManagerTest, SERIAL_INSERT_SEARCH_TEST) {
for (auto &pair : search_vectors) { for (auto &pair : search_vectors) {
auto &search = pair.second; auto &search = pair.second;
engine::QueryResults results; engine::QueryResults results;
stat = db_->Query(TABLE_NAME, k, 1, search.data(), results); stat = db_->Query(TABLE_NAME, k, 1, 10, search.data(), results);
ASSERT_EQ(results[0][0].first, pair.first); ASSERT_EQ(results[0][0].first, pair.first);
ASSERT_LT(results[0][0].second, 0.00001); ASSERT_LT(results[0][0].second, 0.00001);
} }
...@@ -332,7 +332,7 @@ TEST_F(NewMemManagerTest, CONCURRENT_INSERT_SEARCH_TEST) { ...@@ -332,7 +332,7 @@ TEST_F(NewMemManagerTest, CONCURRENT_INSERT_SEARCH_TEST) {
prev_count = count; prev_count = count;
START_TIMER; START_TIMER;
stat = db_->Query(TABLE_NAME, k, qb, qxb.data(), results); stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results);
ss << "Search " << j << " With Size " << count / engine::meta::M << " M"; ss << "Search " << j << " With Size " << count / engine::meta::M << " M";
STOP_TIMER(ss.str()); STOP_TIMER(ss.str());
......
...@@ -90,7 +90,7 @@ TEST_F(DISABLED_MySQLDBTest, DB_TEST) { ...@@ -90,7 +90,7 @@ TEST_F(DISABLED_MySQLDBTest, DB_TEST) {
prev_count = count; prev_count = count;
START_TIMER; START_TIMER;
stat = db_->Query(TABLE_NAME, k, qb, qxb.data(), results); stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results);
ss << "Search " << j << " With Size " << count/engine::meta::M << " M"; ss << "Search " << j << " With Size " << count/engine::meta::M << " M";
STOP_TIMER(ss.str()); STOP_TIMER(ss.str());
...@@ -190,7 +190,7 @@ TEST_F(DISABLED_MySQLDBTest, SEARCH_TEST) { ...@@ -190,7 +190,7 @@ TEST_F(DISABLED_MySQLDBTest, SEARCH_TEST) {
sleep(2); // wait until build index finish sleep(2); // wait until build index finish
engine::QueryResults results; engine::QueryResults results;
stat = db_->Query(TABLE_NAME, k, nq, xq.data(), results); stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results);
ASSERT_STATS(stat); ASSERT_STATS(stat);
delete db_; delete db_;
......
...@@ -38,7 +38,7 @@ TEST(DBSchedulerTest, TASK_QUEUE_TEST) { ...@@ -38,7 +38,7 @@ TEST(DBSchedulerTest, TASK_QUEUE_TEST) {
ASSERT_EQ(ptr, nullptr); ASSERT_EQ(ptr, nullptr);
ASSERT_TRUE(queue.Empty()); ASSERT_TRUE(queue.Empty());
engine::SearchContextPtr context_ptr = std::make_shared<engine::SearchContext>(1, 1, nullptr); engine::SearchContextPtr context_ptr = std::make_shared<engine::SearchContext>(1, 1, 10, nullptr);
for(size_t i = 0; i < 10; i++) { for(size_t i = 0; i < 10; i++) {
auto file = CreateTabileFileStruct(i, "tbl"); auto file = CreateTabileFileStruct(i, "tbl");
context_ptr->AddIndexFile(file); context_ptr->AddIndexFile(file);
...@@ -69,7 +69,7 @@ TEST(DBSchedulerTest, SEARCH_SCHEDULER_TEST) { ...@@ -69,7 +69,7 @@ TEST(DBSchedulerTest, SEARCH_SCHEDULER_TEST) {
task_list.push_back(task_ptr); task_list.push_back(task_ptr);
} }
engine::SearchContextPtr context_ptr = std::make_shared<engine::SearchContext>(1, 1, nullptr); engine::SearchContextPtr context_ptr = std::make_shared<engine::SearchContext>(1, 1, 10, nullptr);
for(size_t i = 0; i < 20; i++) { for(size_t i = 0; i < 20; i++) {
auto file = CreateTabileFileStruct(i, "tbl"); auto file = CreateTabileFileStruct(i, "tbl");
context_ptr->AddIndexFile(file); context_ptr->AddIndexFile(file);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册