提交 ac7403f0 编写于 作者: G groot

search combine configable

Signed-off-by: Ngroot <yihua.mo@zilliz.com>
上级 69eb1d25
......@@ -114,6 +114,8 @@ const char* CONFIG_ENGINE_OMP_THREAD_NUM = "omp_thread_num";
const char* CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT = "0";
const char* CONFIG_ENGINE_SIMD_TYPE = "simd_type";
const char* CONFIG_ENGINE_SIMD_TYPE_DEFAULT = "auto";
const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ = "search_combine_nq";
const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT = "64";
/* gpu resource config */
const char* CONFIG_GPU_RESOURCE = "gpu";
......@@ -200,6 +202,9 @@ Config::Config() {
std::string node_blas_threshold = std::string(CONFIG_ENGINE) + "." + CONFIG_ENGINE_USE_BLAS_THRESHOLD;
config_callback_[node_blas_threshold] = empty_map;
std::string node_search_combine = std::string(CONFIG_ENGINE) + "." + CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ;
config_callback_[node_search_combine] = empty_map;
// gpu resources config
std::string node_gpu_enable = std::string(CONFIG_GPU_RESOURCE) + "." + CONFIG_GPU_RESOURCE_ENABLE;
config_callback_[node_gpu_enable] = empty_map;
......@@ -477,6 +482,7 @@ Config::ResetDefaultConfig() {
STATUS_CHECK(SetEngineConfigUseBlasThreshold(CONFIG_ENGINE_USE_BLAS_THRESHOLD_DEFAULT));
STATUS_CHECK(SetEngineConfigOmpThreadNum(CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT));
STATUS_CHECK(SetEngineConfigSimdType(CONFIG_ENGINE_SIMD_TYPE_DEFAULT));
STATUS_CHECK(SetEngineSearchCombineMaxNq(CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT));
/* gpu resource config */
#ifdef MILVUS_GPU_VERSION
......@@ -613,6 +619,8 @@ Config::SetConfigCli(const std::string& parent_key, const std::string& child_key
status = SetEngineConfigOmpThreadNum(value);
} else if (child_key == CONFIG_ENGINE_SIMD_TYPE) {
status = SetEngineConfigSimdType(value);
} else if (child_key == CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ) {
status = SetEngineSearchCombineMaxNq(value);
} else {
status = Status(SERVER_UNEXPECTED_ERROR, invalid_node_str);
}
......@@ -1552,6 +1560,18 @@ Config::CheckEngineConfigSimdType(const std::string& value) {
return Status::OK();
}
Status
Config::CheckEngineSearchCombineMaxNq(const std::string& value) {
fiu_return_on("check_config_search_combine_nq_fail", Status(SERVER_INVALID_ARGUMENT, ""));
if (!ValidationUtil::ValidateStringIsNumber(value).ok()) {
std::string msg = "Invalid omp thread num: " + value +
". Possible reason: engine_config.omp_thread_num is not a positive integer.";
return Status(SERVER_INVALID_ARGUMENT, msg);
}
return Status::OK();
}
#ifdef MILVUS_GPU_VERSION
/* gpu resource config */
......@@ -2247,6 +2267,15 @@ Config::GetEngineConfigSimdType(std::string& value) {
return CheckEngineConfigSimdType(value);
}
Status
Config::GetEngineSearchCombineMaxNq(int64_t& value) {
std::string str =
GetConfigStr(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT);
// STATUS_CHECK(CheckEngineSearchCombineMaxNq(str));
value = std::stoll(str);
return Status::OK();
}
/* gpu resource config */
#ifdef MILVUS_GPU_VERSION
......@@ -2456,6 +2485,7 @@ Config::SetClusterConfigEnable(const std::string& value) {
STATUS_CHECK(CheckClusterConfigEnable(value));
return SetConfigValueInMem(CONFIG_CLUSTER, CONFIG_CLUSTER_ENABLE, value);
}
Status
Config::SetClusterConfigRole(const std::string& value) {
STATUS_CHECK(CheckClusterConfigRole(value));
......@@ -2685,8 +2715,16 @@ Config::SetEngineConfigSimdType(const std::string& value) {
return SetConfigValueInMem(CONFIG_ENGINE, CONFIG_ENGINE_SIMD_TYPE, value);
}
Status
Config::SetEngineSearchCombineMaxNq(const std::string& value) {
STATUS_CHECK(CheckEngineSearchCombineMaxNq(value));
STATUS_CHECK(SetConfigValueInMem(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, value));
return ExecCallBacks(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, value);
}
/* gpu resource config */
#ifdef MILVUS_GPU_VERSION
Status
Config::SetGpuResourceConfigEnable(const std::string& value) {
STATUS_CHECK(CheckGpuResourceConfigEnable(value));
......@@ -2731,6 +2769,7 @@ Config::SetGpuResourceConfigBuildIndexResources(const std::string& value) {
STATUS_CHECK(SetConfigValueInMem(CONFIG_GPU_RESOURCE, CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES, value));
return ExecCallBacks(CONFIG_GPU_RESOURCE, CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES, value);
}
#endif
/* tracing config */
......
......@@ -102,6 +102,8 @@ extern const char* CONFIG_ENGINE_OMP_THREAD_NUM;
extern const char* CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT;
extern const char* CONFIG_ENGINE_SIMD_TYPE;
extern const char* CONFIG_ENGINE_SIMD_TYPE_DEFAULT;
extern const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ;
extern const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT;
/* gpu resource config */
extern const char* CONFIG_GPU_RESOURCE;
......@@ -268,6 +270,8 @@ class Config {
CheckEngineConfigOmpThreadNum(const std::string& value);
Status
CheckEngineConfigSimdType(const std::string& value);
Status
CheckEngineSearchCombineMaxNq(const std::string& value);
#ifdef MILVUS_GPU_VERSION
/* gpu resource config */
......@@ -388,6 +392,8 @@ class Config {
GetEngineConfigOmpThreadNum(int64_t& value);
Status
GetEngineConfigSimdType(std::string& value);
Status
GetEngineSearchCombineMaxNq(int64_t& value);
#ifdef MILVUS_GPU_VERSION
/* gpu resource config */
......@@ -500,6 +506,8 @@ class Config {
SetEngineConfigOmpThreadNum(const std::string& value);
Status
SetEngineConfigSimdType(const std::string& value);
Status
SetEngineSearchCombineMaxNq(const std::string& value);
#ifdef MILVUS_GPU_VERSION
/* gpu resource config */
......
......@@ -19,10 +19,12 @@ namespace server {
EngineConfigHandler::EngineConfigHandler() {
auto& config = Config::GetInstance();
config.GetEngineConfigUseBlasThreshold(use_blas_threshold_);
config.GetEngineSearchCombineMaxNq(search_combine_nq_);
}
EngineConfigHandler::~EngineConfigHandler() {
RemoveUseBlasThresholdListener();
RemoveSearchCombineMaxNqListener();
}
//////////////////////////// Listener methods //////////////////////////////////
......@@ -48,5 +50,27 @@ EngineConfigHandler::RemoveUseBlasThresholdListener() {
config.CancelCallBack(CONFIG_ENGINE, CONFIG_ENGINE_USE_BLAS_THRESHOLD, identity_);
}
void
EngineConfigHandler::AddSearchCombineMaxNqListener() {
ConfigCallBackF lambda = [this](const std::string& value) -> Status {
auto& config = server::Config::GetInstance();
auto status = config.GetEngineSearchCombineMaxNq(search_combine_nq_);
if (status.ok()) {
OnSearchCombineMaxNqChanged(search_combine_nq_);
}
return status;
};
auto& config = Config::GetInstance();
config.RegisterCallBack(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, identity_, lambda);
}
void
EngineConfigHandler::RemoveSearchCombineMaxNqListener() {
auto& config = Config::GetInstance();
config.CancelCallBack(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, identity_);
}
} // namespace server
} // namespace milvus
......@@ -28,16 +28,27 @@ class EngineConfigHandler : virtual public ConfigHandler {
OnUseBlasThresholdChanged(int64_t threshold) {
}
virtual void
OnSearchCombineMaxNqChanged(int64_t nq) {
search_combine_nq_ = nq;
}
protected:
void
AddUseBlasThresholdListener();
protected:
void
RemoveUseBlasThresholdListener();
void
AddSearchCombineMaxNqListener();
void
RemoveSearchCombineMaxNqListener();
protected:
int64_t use_blas_threshold_ = std::stoll(CONFIG_ENGINE_USE_BLAS_THRESHOLD_DEFAULT);
int64_t search_combine_nq_ = std::stoll(CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT);
};
} // namespace server
......
......@@ -27,7 +27,6 @@ namespace server {
namespace {
constexpr int64_t MAX_TOPK_GAP = 200;
constexpr uint64_t MAX_NQ = 200;
void
GetUniqueList(const std::vector<std::string>& list, std::set<std::string>& unique_list) {
......@@ -93,7 +92,8 @@ class TracingContextList {
} // namespace
SearchCombineRequest::SearchCombineRequest() : BaseRequest(nullptr, BaseRequest::kSearchCombine) {
SearchCombineRequest::SearchCombineRequest(int64_t max_nq)
: BaseRequest(nullptr, BaseRequest::kSearchCombine), combine_max_nq_(max_nq) {
}
Status
......@@ -133,6 +133,8 @@ SearchCombineRequest::Combine(const SearchRequestPtr& request) {
}
request_list_.push_back(request);
vectors_data_.vector_count_ += request->VectorsData().vector_count_;
return Status::OK();
}
......@@ -152,11 +154,11 @@ SearchCombineRequest::CanCombine(const SearchRequestPtr& request) {
}
// sum of nq must less-equal than MAX_NQ
if (vectors_data_.vector_count_ > MAX_NQ || request->VectorsData().vector_count_ > MAX_NQ) {
if (vectors_data_.vector_count_ > combine_max_nq_ || request->VectorsData().vector_count_ > combine_max_nq_) {
return false;
}
uint64_t total_nq = vectors_data_.vector_count_ + request->VectorsData().vector_count_;
if (total_nq > MAX_NQ) {
if (total_nq > combine_max_nq_) {
return false;
}
......@@ -178,7 +180,7 @@ SearchCombineRequest::CanCombine(const SearchRequestPtr& request) {
}
bool
SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right) {
SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right, int64_t max_nq) {
if (left->CollectionName() != right->CollectionName()) {
return false;
}
......@@ -193,11 +195,11 @@ SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchReque
}
// sum of nq must less-equal than MAX_NQ
if (left->VectorsData().vector_count_ > MAX_NQ || right->VectorsData().vector_count_ > MAX_NQ) {
if (left->VectorsData().vector_count_ > max_nq || right->VectorsData().vector_count_ > max_nq) {
return false;
}
uint64_t total_nq = left->VectorsData().vector_count_ + right->VectorsData().vector_count_;
if (total_nq > MAX_NQ) {
if (total_nq > max_nq) {
return false;
}
......
......@@ -22,9 +22,11 @@
namespace milvus {
namespace server {
constexpr int64_t COMBINE_MAX_NQ = 64;
class SearchCombineRequest : public BaseRequest {
public:
SearchCombineRequest();
SearchCombineRequest(int64_t max_nq = COMBINE_MAX_NQ);
Status
Combine(const SearchRequestPtr& request);
......@@ -33,7 +35,7 @@ class SearchCombineRequest : public BaseRequest {
CanCombine(const SearchRequestPtr& request);
static bool
CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right);
CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right, int64_t max_nq = COMBINE_MAX_NQ);
protected:
Status
......@@ -54,6 +56,8 @@ class SearchCombineRequest : public BaseRequest {
std::set<std::string> file_id_list_;
std::vector<SearchRequestPtr> request_list_;
int64_t combine_max_nq_ = COMBINE_MAX_NQ;
};
using SearchCombineRequestPtr = std::shared_ptr<SearchCombineRequest>;
......
......@@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "server/delivery/strategy/SearchReqStrategy.h"
#include "config/Config.h"
#include "server/delivery/request/SearchCombineRequest.h"
#include "server/delivery/request/SearchRequest.h"
#include "utils/CommonUtil.h"
......@@ -24,6 +25,8 @@ namespace milvus {
namespace server {
SearchReqStrategy::SearchReqStrategy() {
SetIdentity("SearchReqStrategy");
AddSearchCombineMaxNqListener();
}
Status
......@@ -34,15 +37,21 @@ SearchReqStrategy::ReScheduleQueue(const BaseRequestPtr& request, std::queue<Bas
return Status(SERVER_UNSUPPORTED_ERROR, msg);
}
// if config set to 0, neve combine
if (search_combine_nq_ <= 0) {
queue.push(request);
return Status::OK();
}
// TimeRecorderAuto rc("SearchReqStrategy::ReScheduleQueue");
SearchRequestPtr new_search_req = std::static_pointer_cast<SearchRequest>(request);
BaseRequestPtr last_req = queue.back();
if (last_req->GetRequestType() == BaseRequest::kSearch) {
SearchRequestPtr last_search_req = std::static_pointer_cast<SearchRequest>(last_req);
if (SearchCombineRequest::CanCombine(last_search_req, new_search_req)) {
if (SearchCombineRequest::CanCombine(last_search_req, new_search_req, search_combine_nq_)) {
// combine request
SearchCombineRequestPtr combine_request = std::make_shared<SearchCombineRequest>();
SearchCombineRequestPtr combine_request = std::make_shared<SearchCombineRequest>(search_combine_nq_);
combine_request->Combine(last_search_req);
combine_request->Combine(new_search_req);
queue.back() = combine_request; // replace the last request to combine request
......
......@@ -11,6 +11,7 @@
#pragma once
#include "config/handler/EngineConfigHandler.h"
#include "server/delivery/strategy/RequestStrategy.h"
#include "utils/Status.h"
......@@ -20,7 +21,7 @@
namespace milvus {
namespace server {
class SearchReqStrategy : public RequestStrategy {
class SearchReqStrategy : public RequestStrategy, public EngineConfigHandler {
public:
SearchReqStrategy();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册