diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index 5e79f04137b17af03a92a188e7189945b49896a5..83c69d7e2e4f51b51349c7127184262fcff18229 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -11,6 +11,7 @@ Please mark all change in change log and use the ticket from JIRA. - MS-417 - YAML sequence load disable cause scheduler startup failed - MS-413 - Create index failed and server exited - MS-427 - Describe index error after drop index +- MS-432 - Search vectors params nprobe need to check max number ## Improvement - MS-327 - Clean code for milvus diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index f6e340f9be15c9833274764528941924f8b4cd7d..168522e21d0287a35d636af06f197881d2c5fd31 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -543,17 +543,17 @@ Status DBImpl::CreateIndex(const std::string& table_id, const TableIndex& index) //step 2: drop old index files DropIndex(table_id); - if(index.engine_type_ == (int)EngineType::FAISS_IDMAP) { - ENGINE_LOG_DEBUG << "index type = IDMAP, no need to build index"; - return Status::OK(); - } - //step 3: update index info status = meta_ptr_->UpdateTableIndexParam(table_id, index); if (!status.ok()) { ENGINE_LOG_ERROR << "Failed to update table index info"; return status; } + + if(index.engine_type_ == (int)EngineType::FAISS_IDMAP) { + ENGINE_LOG_DEBUG << "index type = IDMAP, no need to build index"; + return Status::OK(); + } } bool has = false; diff --git a/cpp/src/server/grpc_impl/GrpcRequestTask.cpp b/cpp/src/server/grpc_impl/GrpcRequestTask.cpp index a4b8d68c296e4066f263ecf41f148e18845e67bc..c559d614acc2ac27ce10d5ee1fa20f78cdfb8d20 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestTask.cpp +++ b/cpp/src/server/grpc_impl/GrpcRequestTask.cpp @@ -573,28 +573,13 @@ SearchTask::OnExecute() { try { TimeRecorder rc("SearchTask"); - //step 1: check arguments + //step 1: check table name std::string table_name_ = search_param_->table_name(); ServerError res = ValidationUtil::ValidateTableName(table_name_); if (res != SERVER_SUCCESS) { return SetError(res, "Invalid table name: " + table_name_); } - int64_t top_k_ = search_param_->topk(); - - if (top_k_ <= 0 || top_k_ > 1024) { - return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(top_k_)); - } - - int64_t nprobe = search_param_->nprobe(); - if (nprobe <= 0) { - return SetError(SERVER_INVALID_NPROBE, "Invalid nprobe: " + std::to_string(nprobe)); - } - - if (search_param_->query_record_array().empty()) { - return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); - } - //step 2: check table existence engine::meta::TableSchema table_info; table_info.table_id_ = table_name_; @@ -607,7 +592,24 @@ SearchTask::OnExecute() { } } - //step 3: check date range, and convert to db dates + //step 3: check search parameter + int64_t top_k = search_param_->topk(); + res = ValidationUtil::ValidateSearchTopk(top_k, table_info); + if (res != SERVER_SUCCESS) { + return SetError(res, "Invalid topk: " + std::to_string(top_k)); + } + + int64_t nprobe = search_param_->nprobe(); + res = ValidationUtil::ValidateSearchNprobe(nprobe, table_info); + if (res != SERVER_SUCCESS) { + return SetError(res, "Invalid nprobe: " + std::to_string(nprobe)); + } + + if (search_param_->query_record_array().empty()) { + return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); + } + + //step 4: check date range, and convert to db dates std::vector dates; ServerError error_code = SERVER_SUCCESS; std::string error_msg; @@ -630,7 +632,7 @@ SearchTask::OnExecute() { ProfilerStart(fname.c_str()); #endif - //step 3: prepare float data + //step 5: prepare float data auto record_array_size = search_param_->query_record_array_size(); std::vector vec_f(record_array_size * table_info.dimension_, 0); for (size_t i = 0; i < record_array_size; i++) { @@ -651,15 +653,15 @@ SearchTask::OnExecute() { } rc.ElapseFromBegin("prepare vector data"); - //step 4: search vectors + //step 6: search vectors engine::QueryResults results; auto record_count = (uint64_t) search_param_->query_record_array().size(); if (file_id_array_.empty()) { - stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, nprobe, vec_f.data(), + stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k, record_count, nprobe, vec_f.data(), dates, results); } else { - stat = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t) top_k_, + stat = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t) top_k, record_count, nprobe, vec_f.data(), dates, results); } @@ -680,7 +682,7 @@ SearchTask::OnExecute() { rc.ElapseFromBegin("do search"); - //step 5: construct result array + //step 7: construct result array for (uint64_t i = 0; i < record_count; i++) { auto &result = results[i]; const auto &record = search_param_->query_record_array(i); @@ -699,10 +701,10 @@ SearchTask::OnExecute() { ProfilerStop(); #endif + //step 8: print time cost percent double span_result = rc.RecordSection("construct result"); rc.ElapseFromBegin("totally cost"); - //step 6: print time cost percent } catch (std::exception &ex) { return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); diff --git a/cpp/src/utils/ValidationUtil.cpp b/cpp/src/utils/ValidationUtil.cpp index 0e2e1c2b4f0dfffeb88cb94ac7704930b14859ac..2245496903aca085b8d28bfc10afe93c902c7506 100644 --- a/cpp/src/utils/ValidationUtil.cpp +++ b/cpp/src/utils/ValidationUtil.cpp @@ -92,6 +92,24 @@ ValidationUtil::ValidateTableIndexMetricType(int32_t metric_type) { return SERVER_SUCCESS; } +ServerError +ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema) { + if (top_k <= 0 || top_k > 1024) { + return SERVER_INVALID_TOPK; + } + + return SERVER_SUCCESS; +} + +ServerError +ValidationUtil::ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema) { + if (nprobe <= 0 || nprobe > table_schema.nlist_) { + return SERVER_INVALID_NPROBE; + } + + return SERVER_SUCCESS; +} + ServerError ValidationUtil::ValidateGpuIndex(uint32_t gpu_index) { int num_devices = 0; diff --git a/cpp/src/utils/ValidationUtil.h b/cpp/src/utils/ValidationUtil.h index 4792500f6719089e49b447f65b1b84f95f1d4a6a..d1d84f710465a6111d093c8953dd95cb8c08256a 100644 --- a/cpp/src/utils/ValidationUtil.h +++ b/cpp/src/utils/ValidationUtil.h @@ -1,5 +1,6 @@ #pragma once +#include "db/meta/MetaTypes.h" #include "Error.h" namespace zilliz { @@ -26,6 +27,12 @@ public: static ServerError ValidateTableIndexMetricType(int32_t metric_type); + static ServerError + ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema); + + static ServerError + ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema); + static ServerError ValidateGpuIndex(uint32_t gpu_index);