提交 608aa1b7 编写于 作者: Y Yu Kun

Add nprobe search parameter and PreloadTable unit test


Former-commit-id: dfaa1c4b36068c9de7499b48094ce39021d1213b
上级 4b887cf0
......@@ -33,14 +33,14 @@ public:
virtual Status InsertVectors(const std::string& table_id_,
uint64_t n, const float* vectors, IDNumbers& vector_ids_) = 0;
virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq,
virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, QueryResults& results) = 0;
virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq,
virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, QueryResults& results) = 0;
virtual Status Query(const std::string& table_id, const std::vector<std::string>& file_ids,
uint64_t k, uint64_t nq, const float* vectors,
uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) = 0;
virtual Status Size(uint64_t& result) = 0;
......
......@@ -189,11 +189,11 @@ Status DBImpl::InsertVectors(const std::string& table_id_,
}
Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq,
Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq, uint64_t nprobe,
const float *vectors, QueryResults &results) {
auto start_time = METRICS_NOW_TIME;
meta::DatesT dates = {meta::Meta::GetDate()};
Status result = Query(table_id, k, nq, vectors, dates, results);
Status result = Query(table_id, k, nq, nprobe, vectors, dates, results);
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time,end_time);
......@@ -202,7 +202,7 @@ Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq,
return result;
}
Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq,
Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, QueryResults& results) {
ENGINE_LOG_DEBUG << "Query by vectors";
......@@ -219,13 +219,13 @@ Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq,
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info before query
status = QueryAsync(table_id, file_id_array, k, nq, vectors, dates, results);
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, dates, results);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info after query
return status;
}
Status DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ids,
uint64_t k, uint64_t nq, const float* vectors,
uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) {
ENGINE_LOG_DEBUG << "Query by file ids";
......@@ -256,20 +256,20 @@ Status DBImpl::Query(const std::string& table_id, const std::vector<std::string>
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info before query
status = QueryAsync(table_id, file_id_array, k, nq, vectors, dates, results);
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, dates, results);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info after query
return status;
}
Status DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files,
uint64_t k, uint64_t nq, const float* vectors,
uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) {
auto start_time = METRICS_NOW_TIME;
server::TimeRecorder rc("");
//step 1: get files to search
ENGINE_LOG_DEBUG << "Engine query begin, index file count:" << files.size() << " date range count:" << dates.size();
SearchContextPtr context = std::make_shared<SearchContext>(k, nq, vectors);
SearchContextPtr context = std::make_shared<SearchContext>(k, nq, nprobe, vectors);
for (auto &file : files) {
TableFileSchemaPtr file_ptr = std::make_shared<meta::TableFileSchema>(file);
context->AddIndexFile(file_ptr);
......
......@@ -61,12 +61,18 @@ class DBImpl : public DB {
InsertVectors(const std::string &table_id, uint64_t n, const float *vectors, IDNumbers &vector_ids) override;
Status
Query(const std::string &table_id, uint64_t k, uint64_t nq, const float *vectors, QueryResults &results) override;
Query(const std::string &table_id,
uint64_t k,
uint64_t nq,
uint64_t nprobe,
const float *vectors,
QueryResults &results) override;
Status
Query(const std::string &table_id,
uint64_t k,
uint64_t nq,
uint64_t nprobe,
const float *vectors,
const meta::DatesT &dates,
QueryResults &results) override;
......@@ -76,6 +82,7 @@ class DBImpl : public DB {
const std::vector<std::string> &file_ids,
uint64_t k,
uint64_t nq,
uint64_t nprobe,
const float *vectors,
const meta::DatesT &dates,
QueryResults &results) override;
......@@ -94,6 +101,7 @@ class DBImpl : public DB {
const meta::TableFilesSchema &files,
uint64_t k,
uint64_t nq,
uint64_t nprobe,
const float *vectors,
const meta::DatesT &dates,
QueryResults &results);
......
......@@ -13,10 +13,11 @@ namespace zilliz {
namespace milvus {
namespace engine {
SearchContext::SearchContext(uint64_t topk, uint64_t nq, const float* vectors)
SearchContext::SearchContext(uint64_t topk, uint64_t nq, uint64_t nprobe, const float* vectors)
: IScheduleContext(ScheduleContextType::kSearch),
topk_(topk),
nq_(nq),
nprobe_(nprobe),
vectors_(vectors) {
//use current time to identify this context
std::chrono::system_clock::time_point tp = std::chrono::system_clock::now();
......
......@@ -21,7 +21,7 @@ using TableFileSchemaPtr = std::shared_ptr<meta::TableFileSchema>;
class SearchContext : public IScheduleContext {
public:
SearchContext(uint64_t topk, uint64_t nq, const float* vectors);
SearchContext(uint64_t topk, uint64_t nq, uint64_t nprobe, const float* vectors);
bool AddIndexFile(TableFileSchemaPtr& index_file);
......@@ -53,6 +53,7 @@ public:
private:
uint64_t topk_ = 0;
uint64_t nq_ = 0;
uint64_t nprobe_ = 0;
const float* vectors_ = nullptr;
Id2IndexMap map_index_files_;
......
......@@ -174,7 +174,7 @@ namespace {
std::vector<TopKQueryResult> topk_query_result_array;
{
TimeRecorder rc(phase_name);
Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, topk_query_result_array);
Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 0, topk_query_result_array);
std::cout << "SearchVector function call status: " << stat.ToString() << std::endl;
}
......
......@@ -210,12 +210,14 @@ ClientProxy::Search(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) {
try {
//step 1: convert vectors data
::milvus::grpc::SearchParam search_param;
search_param.set_table_name(table_name);
search_param.set_topk(topk);
search_param.set_nprobe(nprobe);
for (auto &record : query_record_array) {
::milvus::grpc::RowRecord *row_record = search_param.add_query_record_array();
for (auto &rec : record.data) {
......
......@@ -47,6 +47,7 @@ public:
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) override;
virtual Status
......
......@@ -247,6 +247,7 @@ class Connection {
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) = 0;
/**
......
......@@ -83,9 +83,10 @@ ConnectionImpl::Search(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) {
return client_proxy_->Search(table_name, query_record_array, query_range_array, topk,
topk_query_result_array);
nprobe, topk_query_result_array);
}
Status
......
......@@ -53,6 +53,7 @@ public:
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) override;
virtual Status
......
......@@ -510,12 +510,17 @@ SearchTask::OnExecute() {
return SetError(res, "Invalid table name: " + table_name_);
}
int top_k_ = search_param_.topk();
int64_t top_k_ = search_param_.topk();
if (top_k_ <= 0 || top_k_ > 1024) {
return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(
top_k_));
return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(top_k_));
}
int64_t nprobe = search_param_.nprobe();
if (nprobe <= 0) {
return SetError(SERVER_INVALID_NPROBE, "Invalid nprobe: " + std::to_string(nprobe));
}
if (search_param_.query_record_array().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
}
......@@ -584,11 +589,11 @@ SearchTask::OnExecute() {
auto record_count = (uint64_t) search_param_.query_record_array().size();
if (file_id_array_.empty()) {
stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, vec_f.data(),
stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, nprobe, vec_f.data(),
dates, results);
} else {
stat = DBWrapper::DB()->Query(table_name_, file_id_array_,
(size_t) top_k_, record_count, vec_f.data(), dates, results);
stat = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t) top_k_,
record_count, nprobe, vec_f.data(), dates, results);
}
rc.ElapseFromBegin("search vectors from engine");
......
......@@ -50,6 +50,8 @@ constexpr ServerError SERVER_ILLEGAL_VECTOR_ID = ToGlobalServerErrorCode(109);
constexpr ServerError SERVER_ILLEGAL_SEARCH_RESULT = ToGlobalServerErrorCode(110);
constexpr ServerError SERVER_CACHE_ERROR = ToGlobalServerErrorCode(111);
constexpr ServerError SERVER_WRITE_ERROR = ToGlobalServerErrorCode(112);
// Reported by SearchTask::OnExecute when a search request carries nprobe <= 0.
constexpr ServerError SERVER_INVALID_NPROBE = ToGlobalServerErrorCode(113);
// License-related errors use a separate numeric range starting at 500.
constexpr ServerError SERVER_LICENSE_FILE_NOT_EXIST = ToGlobalServerErrorCode(500);
constexpr ServerError SERVER_LICENSE_VALIDATION_FAIL = ToGlobalServerErrorCode(501);
......
......@@ -226,6 +226,39 @@ TEST_F(DBTest, SEARCH_TEST) {
// TODO(linxj): add groundTruth assert
};
// Checks DB::PreloadTable on a populated table: after inserting vectors and
// calling PreloadTable, the usage reported by cache::CpuCacheMgr must be
// strictly greater than the usage sampled just before the call.
TEST_F(DBTest, PRELOADTABLE_TEST) {
    engine::meta::TableSchema table_info = BuildTableSchema();
    engine::Status stat = db_->CreateTable(table_info);

    // Read the schema back to confirm the table is visible and has the
    // expected dimension before inserting anything.
    engine::meta::TableSchema table_info_get;
    table_info_get.table_id_ = TABLE_NAME;
    stat = db_->DescribeTable(table_info_get);
    ASSERT_STATS(stat);
    ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);

    engine::IDNumbers target_ids;

    // One batch of nb vectors, re-inserted INSERT_LOOP times below.
    int64_t nb = 50;
    std::vector<float> xb;
    BuildVectors(nb, xb);

    int loop = INSERT_LOOP;
    for (auto i = 0; i < loop; ++i) {
        // Insert the batch built above (nb / xb). The previous version passed
        // qb / qxb, which are not declared anywhere in this test.
        db_->InsertVectors(TABLE_NAME, nb, xb.data(), target_ids);
        ASSERT_EQ(target_ids.size(), nb);
    }

    // Sample cache usage on both sides of PreloadTable; it is expected to grow.
    int64_t prev_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage();
    stat = db_->PreloadTable(TABLE_NAME);
    ASSERT_STATS(stat);
    int64_t cur_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage();
    ASSERT_TRUE(prev_cache_usage < cur_cache_usage);
}
TEST_F(DBTest2, ARHIVE_DISK_CHECK) {
engine::meta::TableSchema table_info = BuildTableSchema();
......@@ -309,4 +342,4 @@ TEST_F(DBTest2, DELETE_TEST) {
db_->HasTable(TABLE_NAME, has_table);
ASSERT_FALSE(has_table);
};
};
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册