提交 7cf1917a 编写于 作者: P peng.xu

Merge branch 'branch-0.4.0' into 'branch-0.4.0'

MS-432 Search vectors params nprobe need to check max number

See merge request megasearch/milvus!444

Former-commit-id: e85b496ac6786776b246c4903ad8e0f3513f9af8
...@@ -11,6 +11,7 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -11,6 +11,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-417 - YAML sequence load disable cause scheduler startup failed - MS-417 - YAML sequence load disable cause scheduler startup failed
- MS-413 - Create index failed and server exited - MS-413 - Create index failed and server exited
- MS-427 - Describe index error after drop index - MS-427 - Describe index error after drop index
- MS-432 - Search vectors params nprobe need to check max number
## Improvement ## Improvement
- MS-327 - Clean code for milvus - MS-327 - Clean code for milvus
......
...@@ -543,17 +543,17 @@ Status DBImpl::CreateIndex(const std::string& table_id, const TableIndex& index) ...@@ -543,17 +543,17 @@ Status DBImpl::CreateIndex(const std::string& table_id, const TableIndex& index)
//step 2: drop old index files //step 2: drop old index files
DropIndex(table_id); DropIndex(table_id);
if(index.engine_type_ == (int)EngineType::FAISS_IDMAP) {
ENGINE_LOG_DEBUG << "index type = IDMAP, no need to build index";
return Status::OK();
}
//step 3: update index info //step 3: update index info
status = meta_ptr_->UpdateTableIndexParam(table_id, index); status = meta_ptr_->UpdateTableIndexParam(table_id, index);
if (!status.ok()) { if (!status.ok()) {
ENGINE_LOG_ERROR << "Failed to update table index info"; ENGINE_LOG_ERROR << "Failed to update table index info";
return status; return status;
} }
if(index.engine_type_ == (int)EngineType::FAISS_IDMAP) {
ENGINE_LOG_DEBUG << "index type = IDMAP, no need to build index";
return Status::OK();
}
} }
bool has = false; bool has = false;
......
...@@ -573,28 +573,13 @@ SearchTask::OnExecute() { ...@@ -573,28 +573,13 @@ SearchTask::OnExecute() {
try { try {
TimeRecorder rc("SearchTask"); TimeRecorder rc("SearchTask");
//step 1: check arguments //step 1: check table name
std::string table_name_ = search_param_->table_name(); std::string table_name_ = search_param_->table_name();
ServerError res = ValidationUtil::ValidateTableName(table_name_); ServerError res = ValidationUtil::ValidateTableName(table_name_);
if (res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + table_name_); return SetError(res, "Invalid table name: " + table_name_);
} }
int64_t top_k_ = search_param_->topk();
if (top_k_ <= 0 || top_k_ > 1024) {
return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(top_k_));
}
int64_t nprobe = search_param_->nprobe();
if (nprobe <= 0) {
return SetError(SERVER_INVALID_NPROBE, "Invalid nprobe: " + std::to_string(nprobe));
}
if (search_param_->query_record_array().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
}
//step 2: check table existence //step 2: check table existence
engine::meta::TableSchema table_info; engine::meta::TableSchema table_info;
table_info.table_id_ = table_name_; table_info.table_id_ = table_name_;
...@@ -607,7 +592,24 @@ SearchTask::OnExecute() { ...@@ -607,7 +592,24 @@ SearchTask::OnExecute() {
} }
} }
//step 3: check date range, and convert to db dates //step 3: check search parameter
int64_t top_k = search_param_->topk();
res = ValidationUtil::ValidateSearchTopk(top_k, table_info);
if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid topk: " + std::to_string(top_k));
}
int64_t nprobe = search_param_->nprobe();
res = ValidationUtil::ValidateSearchNprobe(nprobe, table_info);
if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid nprobe: " + std::to_string(nprobe));
}
if (search_param_->query_record_array().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
}
//step 4: check date range, and convert to db dates
std::vector<DB_DATE> dates; std::vector<DB_DATE> dates;
ServerError error_code = SERVER_SUCCESS; ServerError error_code = SERVER_SUCCESS;
std::string error_msg; std::string error_msg;
...@@ -630,7 +632,7 @@ SearchTask::OnExecute() { ...@@ -630,7 +632,7 @@ SearchTask::OnExecute() {
ProfilerStart(fname.c_str()); ProfilerStart(fname.c_str());
#endif #endif
//step 3: prepare float data //step 5: prepare float data
auto record_array_size = search_param_->query_record_array_size(); auto record_array_size = search_param_->query_record_array_size();
std::vector<float> vec_f(record_array_size * table_info.dimension_, 0); std::vector<float> vec_f(record_array_size * table_info.dimension_, 0);
for (size_t i = 0; i < record_array_size; i++) { for (size_t i = 0; i < record_array_size; i++) {
...@@ -651,15 +653,15 @@ SearchTask::OnExecute() { ...@@ -651,15 +653,15 @@ SearchTask::OnExecute() {
} }
rc.ElapseFromBegin("prepare vector data"); rc.ElapseFromBegin("prepare vector data");
//step 4: search vectors //step 6: search vectors
engine::QueryResults results; engine::QueryResults results;
auto record_count = (uint64_t) search_param_->query_record_array().size(); auto record_count = (uint64_t) search_param_->query_record_array().size();
if (file_id_array_.empty()) { if (file_id_array_.empty()) {
stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, nprobe, vec_f.data(), stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k, record_count, nprobe, vec_f.data(),
dates, results); dates, results);
} else { } else {
stat = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t) top_k_, stat = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t) top_k,
record_count, nprobe, vec_f.data(), dates, results); record_count, nprobe, vec_f.data(), dates, results);
} }
...@@ -680,7 +682,7 @@ SearchTask::OnExecute() { ...@@ -680,7 +682,7 @@ SearchTask::OnExecute() {
rc.ElapseFromBegin("do search"); rc.ElapseFromBegin("do search");
//step 5: construct result array //step 7: construct result array
for (uint64_t i = 0; i < record_count; i++) { for (uint64_t i = 0; i < record_count; i++) {
auto &result = results[i]; auto &result = results[i];
const auto &record = search_param_->query_record_array(i); const auto &record = search_param_->query_record_array(i);
...@@ -699,10 +701,10 @@ SearchTask::OnExecute() { ...@@ -699,10 +701,10 @@ SearchTask::OnExecute() {
ProfilerStop(); ProfilerStop();
#endif #endif
//step 8: print time cost percent
double span_result = rc.RecordSection("construct result"); double span_result = rc.RecordSection("construct result");
rc.ElapseFromBegin("totally cost"); rc.ElapseFromBegin("totally cost");
//step 6: print time cost percent
} catch (std::exception &ex) { } catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
......
...@@ -92,6 +92,24 @@ ValidationUtil::ValidateTableIndexMetricType(int32_t metric_type) { ...@@ -92,6 +92,24 @@ ValidationUtil::ValidateTableIndexMetricType(int32_t metric_type) {
return SERVER_SUCCESS; return SERVER_SUCCESS;
} }
ServerError
ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema) {
if (top_k <= 0 || top_k > 1024) {
return SERVER_INVALID_TOPK;
}
return SERVER_SUCCESS;
}
ServerError
ValidationUtil::ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema) {
if (nprobe <= 0 || nprobe > table_schema.nlist_) {
return SERVER_INVALID_NPROBE;
}
return SERVER_SUCCESS;
}
ServerError ServerError
ValidationUtil::ValidateGpuIndex(uint32_t gpu_index) { ValidationUtil::ValidateGpuIndex(uint32_t gpu_index) {
int num_devices = 0; int num_devices = 0;
......
#pragma once #pragma once
#include "db/meta/MetaTypes.h"
#include "Error.h" #include "Error.h"
namespace zilliz { namespace zilliz {
...@@ -26,6 +27,12 @@ public: ...@@ -26,6 +27,12 @@ public:
static ServerError static ServerError
ValidateTableIndexMetricType(int32_t metric_type); ValidateTableIndexMetricType(int32_t metric_type);
static ServerError
ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema);
static ServerError
ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema);
static ServerError static ServerError
ValidateGpuIndex(uint32_t gpu_index); ValidateGpuIndex(uint32_t gpu_index);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册