diff --git a/predictor/common/inner_common.h b/predictor/common/inner_common.h index 43ef30f0fd3fe335e23362abfbc6cb97a42124bd..b1f4ec8ba6a8d79db5fbdf2e491804f6d39b052a 100644 --- a/predictor/common/inner_common.h +++ b/predictor/common/inner_common.h @@ -9,6 +9,8 @@ #include #include +#include + #include #include #include diff --git a/predictor/framework/dag.cpp b/predictor/framework/dag.cpp index c0ec7f97fa1a465b4365da815b5f7b2d31be5d49..cbfd0ea60ae856ce42a332aa698a626c30c5f4ab 100644 --- a/predictor/framework/dag.cpp +++ b/predictor/framework/dag.cpp @@ -85,6 +85,7 @@ EdgeMode Dag::parse_mode(std::string& mode) { // [.@Depend] // name: dnn_inference // mode: RO +#if 0 int Dag::init(const char* path, const char* file, const std::string& name) { comcfg::Configure conf; if (conf.load(path, file) != 0) { @@ -96,26 +97,27 @@ int Dag::init(const char* path, const char* file, const std::string& name) { return init(conf, name); } +#endif -int Dag::init(const comcfg::Configure& conf, const std::string& name) { +int Dag::init(const configure::Workflow& conf, const std::string& name) { _dag_name = name; _index_nodes.clear(); _name_nodes.clear(); - for (uint32_t i = 0; i < conf["Node"].size(); i++) { + for (uint32_t i = 0; i < conf.nodes_size(); i++) { DagNode* node = new (std::nothrow) DagNode(); if (node == NULL) { LOG(ERROR) << "Failed create new dag node"; return ERR_MEM_ALLOC_FAILURE; } node->id = i + 1; // 0 is reserved for begginer-op - node->name = conf["Node"][i]["name"].to_cstr(); - node->type = conf["Node"][i]["type"].to_cstr(); - uint32_t depend_size = conf["Node"][i]["Depend"].size(); + node->name = conf.nodes(i).name(); + node->type = conf.nodes(i).type(); + uint32_t depend_size = conf.nodes(i).dependencies_size(); for (uint32_t j = 0; j < depend_size; j++) { - const comcfg::ConfigUnit& depend = - conf["Node"][i]["Depend"][j]; - std::string name = depend["name"].to_cstr(); - std::string mode = depend["mode"].to_cstr(); + const configure::DAGNodeDependency& depend = + conf.nodes(i).dependencies(j); + std::string name = depend.name(); + std::string mode = depend.mode(); node->depends.insert( std::make_pair(name, parse_mode(mode))); } @@ -125,7 +127,7 @@ int Dag::init(const comcfg::Configure& conf, const std::string& name) { return ERR_INTERNAL_FAILURE; } // node->conf could be NULL - node->conf = op->create_config(conf["Node"][i]); + node->conf = op->create_config(conf.nodes(i)); OpRepository::instance().return_op(node->type, op); _name_nodes.insert(std::make_pair(node->name, node)); _index_nodes.push_back(node); diff --git a/predictor/framework/dag.h b/predictor/framework/dag.h index 7897964b2b4fdf3f9052637a856c1fe19a7ea602..1373ee2ed408ab5006fcc66889b49d6b808b7981 100644 --- a/predictor/framework/dag.h +++ b/predictor/framework/dag.h @@ -39,7 +39,7 @@ public: int init(const char* path, const char* file, const std::string& name); - int init(const comcfg::Configure& conf, const std::string& name); + int init(const configure::Workflow& conf, const std::string& name); int deinit(); diff --git a/predictor/framework/infer.h b/predictor/framework/infer.h index 5d41c144814b2a950ccff4bf6b7258fe7d398bd4..634864b9d833cf337ccefe181098a044abfac590 100644 --- a/predictor/framework/infer.h +++ b/predictor/framework/infer.h @@ -181,119 +181,10 @@ public: private: int parse_version_info(const configure::EngineDesc& config, bool version) { - try { - std::string version_file = config.version_file(); - std::string version_type = config.version_type(); - - if (version_type == "abacus_version") { - if (parse_abacus_version(version_file) != 0) { - LOG(FATAL) - << "Failed parse abacus version: " << version_file; - return -1; - } - } else if (version_type == "corece_uint64") { - if (parse_corece_uint64(version_file) != 0) { - LOG(FATAL) - << "Failed parse corece_uint64: " << version_file; - return -1; - } - } else { - LOG(FATAL) << "Not supported version_type: " << version_type; - return -1; - } - } catch (comcfg::ConfigException e) { // no version file - if (version) { - LOG(FATAL) << "Cannot parse version engine, err:" - << e.what(); - return -1; - } - - LOG(WARNING) << "Consistency with non-versioned configure"; - _version = uint64_t(-1); - } + _version = uint64_t(-1); return 0; } - int parse_abacus_version(const std::string& version_file) { - FILE* fp = fopen(version_file.c_str(), "r"); - if (!fp) { - LOG(FATAL) << "Failed open version file:" << version_file; - return -1; - } - - bool has_parsed = false; - char buffer[1024] = {0}; - while (fgets(buffer, sizeof(buffer), fp)) { - char* begin = NULL; - if (strncmp(buffer, "version:", 8) == 0 || - strncmp(buffer, "Version:", 8) == 0) { - begin = buffer + 8; - } else if (strncmp(buffer, "version :", 9) == 0 || - strncmp(buffer, "Version :", 9) == 0) { - begin = buffer + 9; - } else { - LOG(WARNING) << "Not version line: " << buffer; - continue; - } - - std::string vstr = begin; - boost::algorithm::trim_if( - vstr, boost::algorithm::is_any_of("\n\r ")); - char* endptr = NULL; - _version = strtoull(vstr.c_str(), &endptr, 10); - if (endptr == vstr.c_str()) { - LOG(FATAL) - << "Invalid version: [" << buffer << "], end: [" - << endptr << "]" << ", vstr: [" << vstr << "]"; - fclose(fp); - return -1; - } - has_parsed = true; - } - - if (!has_parsed) { - LOG(FATAL) << "Failed parse abacus version: " << version_file; - fclose(fp); - return -1; - } - - LOG(WARNING) << "Succ parse abacus version: " << _version - << " from: " << version_file; - fclose(fp); - return 0; - } - - int parse_corece_uint64(const std::string& version_file) { - FILE* fp = fopen(version_file.c_str(), "r"); - if (!fp) { - LOG(FATAL) << "Failed open version file:" << version_file; - return -1; - } - - bool has_parsed = false; - char buffer[1024] = {0}; - if (fgets(buffer, sizeof(buffer), fp)) { - char* endptr = NULL; - _version = strtoull(buffer, &endptr, 10); - if (endptr == buffer) { - LOG(FATAL) << "Invalid version: " << buffer; - fclose(fp); - return -1; - } - has_parsed = true; - } - - if (!has_parsed) { - LOG(FATAL) << "Failed parse abacus version: " << version_file; - fclose(fp); - return -1; - } - - LOG(WARNING) << "Succ parse corece version: " << _version - << " from: " << version_file; - fclose(fp); - return 0; - } bool check_need_reload() { if (_reload_mode_tag == "timestamp_ne") { @@ -756,23 +647,13 @@ public: } ~VersionedInferEngine() {} - int proc_initialize(const configure::VersionedEngine& conf) { - size_t version_num = conf.versions_size(); - for (size_t vi = 0; vi < version_num; ++vi) { - if (proc_initialize(conf.versions(vi), true) != 0) { - LOG(FATAL) << "Failed proc initialize version: " - << vi << ", model: " << conf.name().c_str(); - return -1; - } + int proc_initialize(const configure::EngineDesc& conf) { + if (proc_initialize(conf, false) != 0) { + LOG(FATAL) << "Failed proc intialize engine: " + << conf.name().c_str(); + return -1; } - if (version_num == 0) { - if (proc_initialize(conf.default_version(), false) != 0) { - LOG(FATAL) << "Failed proc intialize engine: " - << conf.name().c_str(); - return -1; - } - } LOG(WARNING) << "Succ proc initialize engine: " << conf.name().c_str(); return 0; diff --git a/predictor/framework/manager.h b/predictor/framework/manager.h index d7fa50f6a202653bb16931e317c45a1f9d163ca6..d1701c25d30fbce67734770b3c7d0e079264cb7d 100644 --- a/predictor/framework/manager.h +++ b/predictor/framework/manager.h @@ -11,6 +11,7 @@ namespace paddle_serving { namespace predictor { using configure::WorkflowConf; +using configure::InferServiceConf; class Workflow; //class InferService; @@ -30,32 +31,31 @@ inline InferService* create_item_impl() { } } -template -class Manager { +class WorkflowManager { public: - static Manager& instance() { - static Manager mgr; + static WorkflowManager& instance() { + static WorkflowManager mgr; return mgr; } int initialize(const std::string path, const std::string file) { WorkflowConf workflow_conf; if (configure::read_proto_conf(path, file, &workflow_conf) != 0) { - LOG(FATAL) << "Failed load manager<" << typeid.name() << "> configure!"; + LOG(FATAL) << "Failed load manager<" << Workflow::tag() << "> configure from " << path << "/" << file; return -1; } try { - uint32_t item_size = conf[T::tag()].size(); + uint32_t item_size = workflow_conf.workflows_size(); for (uint32_t ii = 0; ii < item_size; ii++) { - std::string name = conf[T::tag()][ii]["name"].to_cstr(); - T* item = new (std::nothrow) T(); + std::string name = workflow_conf.workflows(ii).name(); + Workflow* item = new (std::nothrow) Workflow(); if (item == NULL) { - LOG(FATAL) << "Failed create " << T::tag() << " for: " << name; + LOG(FATAL) << "Failed create " << Workflow::tag() << " for: " << name; return -1; } - if (item->init(conf[T::tag()][ii]) != 0) { + if (item->init(workflow_conf.workflows(ii)) != 0) { LOG(FATAL) << "Failed init item: " << name << " at:" << ii << "!"; @@ -63,7 +63,7 @@ public: } std::pair< - typename boost::unordered_map::iterator, bool> + typename boost::unordered_map::iterator, bool> r = _item_map.insert(std::make_pair(name, item)); if (!r.second) { LOG(FATAL) @@ -91,12 +91,12 @@ public: return 0; } - T* create_item() { - return create_item_impl(); + Workflow* create_item() { + return create_item_impl(); } - T* item(const std::string& name) { - typename boost::unordered_map::iterator it; + Workflow* item(const std::string& name) { + typename boost::unordered_map::iterator it; it = _item_map.find(name); if (it == _item_map.end()) { LOG(WARNING) << "Not found item: " << name << "!"; @@ -106,8 +106,8 @@ public: return it->second; } - T& operator[](const std::string& name) { - T* i = item(name); + Workflow& operator[](const std::string& name) { + Workflow* i = item(name); if (i == NULL) { std::string err = "Not found item in manager for:"; err += name; @@ -118,7 +118,7 @@ public: int reload() { int ret = 0; - typename boost::unordered_map::iterator it + typename boost::unordered_map::iterator it = _item_map.begin(); for (; it != _item_map.end(); ++it) { if (it->second->reload() != 0) { @@ -129,7 +129,7 @@ public: LOG(INFO) << "Finish reload " << _item_map.size() - << " " << T::tag() << "(s)"; + << " " << Workflow::tag() << "(s)"; return ret; } @@ -138,14 +138,124 @@ public: } private: - Manager() {} + WorkflowManager() {} private: - boost::unordered_map _item_map; + boost::unordered_map _item_map; }; -typedef Manager InferServiceManager; -typedef Manager WorkflowManager; +class InferServiceManager { +public: + static InferServiceManager& instance() { + static InferServiceManager mgr; + return mgr; + } + + int initialize(const std::string path, const std::string file) { + InferServiceConf infer_service_conf; + if (configure::read_proto_conf(path, file, &infer_service_conf) != 0) { + LOG(FATAL) << "Failed load manager<" << InferService::tag() << "> configure!"; + return -1; + } + + try { + + uint32_t item_size = infer_service_conf.services_size(); + for (uint32_t ii = 0; ii < item_size; ii++) { + std::string name = infer_service_conf.services(ii).name(); + InferService* item = new (std::nothrow) InferService(); + if (item == NULL) { + LOG(FATAL) << "Failed create " << InferService::tag() << " for: " << name; + return -1; + } + if (item->init(infer_service_conf.services(ii)) != 0) { + LOG(FATAL) + << "Failed init item: " << name << " at:" + << ii << "!"; + return -1; + } + + std::pair< + typename boost::unordered_map::iterator, bool> + r = _item_map.insert(std::make_pair(name, item)); + if (!r.second) { + LOG(FATAL) + << "Failed insert item:" << name << " at:" + << ii << "!"; + return -1; + } + + LOG(INFO) + << "Succ init item:" << name << " from conf:" + << path << "/" << file << ", at:" << ii << "!"; + } + + } catch (comcfg::ConfigException e) { + LOG(FATAL) + << "Config[" << path << "/" << file << "] format " + << "invalid, err: " << e.what(); + return -1; + } catch (...) { + LOG(FATAL) + << "Config[" << path << "/" << file << "] format " + << "invalid, load failed"; + return -1; + } + return 0; + } + + InferService* create_item() { + return create_item_impl(); + } + + InferService* item(const std::string& name) { + typename boost::unordered_map::iterator it; + it = _item_map.find(name); + if (it == _item_map.end()) { + LOG(WARNING) << "Not found item: " << name << "!"; + return NULL; + } + + return it->second; + } + + InferService& operator[](const std::string& name) { + InferService* i = item(name); + if (i == NULL) { + std::string err = "Not found item in manager for:"; + err += name; + throw std::overflow_error(err); + } + return *i; + } + + int reload() { + int ret = 0; + typename boost::unordered_map::iterator it + = _item_map.begin(); + for (; it != _item_map.end(); ++it) { + if (it->second->reload() != 0) { + LOG(WARNING) << "failed reload item: " << it->first << "!"; + ret = -1; + } + } + + LOG(INFO) << "Finish reload " + << _item_map.size() + << " " << InferService::tag() << "(s)"; + return ret; + } + + int finalize() { + return 0; + } + +private: + InferServiceManager() {} + +private: + boost::unordered_map _item_map; +}; } // predictor } // paddle_serving diff --git a/predictor/framework/service.cpp b/predictor/framework/service.cpp index 678ec062b50be12fe37cbb46e213556095ca819c..5ec5f1ae4829414bca4a366ab7b1d15bfe247dfe 100644 --- a/predictor/framework/service.cpp +++ b/predictor/framework/service.cpp @@ -13,10 +13,13 @@ namespace baidu { namespace paddle_serving { namespace predictor { -int InferService::init(const comcfg::ConfigUnit& conf) { - _infer_service_format = conf["name"].to_cstr(); - char merger[256]; - conf["merger"].get_cstr(merger, sizeof(merger), "default"); +int InferService::init(const configure::InferService& conf) { + _infer_service_format = conf.name(); + + std::string merger = conf.merger(); + if (merger == "") { + merger = "default"; + } if (!MergerManager::instance().get(merger, _merger)) { LOG(ERROR) << "Failed get merger: " << merger; return ERR_INTERNAL_FAILURE; @@ -24,6 +27,7 @@ int InferService::init(const comcfg::ConfigUnit& conf) { LOG(WARNING) << "Succ get merger: " << merger << " for service: " << _infer_service_format; } + ServerManager& svr_mgr = ServerManager::instance(); if (svr_mgr.add_service_by_format(_infer_service_format) != 0) { LOG(FATAL) @@ -32,14 +36,11 @@ int InferService::init(const comcfg::ConfigUnit& conf) { return ERR_INTERNAL_FAILURE; } - uint32_t default_value = 0; - conf["enable_map_request_to_workflow"].get_uint32(&default_value, 0); - _enable_map_request_to_workflow = (default_value != 0); + _enable_map_request_to_workflow = conf.enable_map_request_to_workflow(); LOG(INFO) << "service[" << _infer_service_format << "], enable_map_request_to_workflow[" << _enable_map_request_to_workflow << "]."; - uint32_t flow_size = conf["workflow"].size(); if (_enable_map_request_to_workflow) { if (_request_to_workflow_map.init( MAX_WORKFLOW_NUM_IN_ONE_SERVICE/*load_factor=80*/) != 0) { @@ -49,31 +50,23 @@ int InferService::init(const comcfg::ConfigUnit& conf) { return ERR_INTERNAL_FAILURE; } int err = 0; - const char* pchar = conf["request_field_key"].to_cstr(&err); - if (err != 0) { + _request_field_key = conf.request_field_key().c_str(); + if (_request_field_key == "") { LOG(FATAL) - << "read request_field_key failed, err_code[" - << err << "]."; + << "read request_field_key failed, request_field_key[" + << _request_field_key << "]."; return ERR_INTERNAL_FAILURE; } - _request_field_key = std::string(pchar); + LOG(INFO) << "service[" << _infer_service_format << "], request_field_key[" << _request_field_key << "]."; - uint32_t request_field_value_size = conf["request_field_value"].size(); - if (request_field_value_size != flow_size) { - LOG(FATAL) - << "flow_size[" << flow_size - << "] not equal request_field_value_size[" - << request_field_value_size << "]."; - return ERR_INTERNAL_FAILURE; - } - - for (uint32_t fi = 0; fi < flow_size; fi++) { + uint32_t value_mapped_workflows_size = conf.value_mapped_workflows_size(); + for (uint32_t fi = 0; fi < value_mapped_workflows_size; fi++) { std::vector tokens; std::vector workflows; - std::string list = conf["workflow"][fi].to_cstr(); + std::string list = conf.value_mapped_workflows(fi).workflow(); boost::split(tokens, list, boost::is_any_of(",")); uint32_t tsize = tokens.size(); for (uint32_t ti = 0; ti < tsize; ++ti) { @@ -89,7 +82,8 @@ int InferService::init(const comcfg::ConfigUnit& conf) { workflow->regist_metric(full_name()); workflows.push_back(workflow); } - const std::string& request_field_value = conf["request_field_value"][fi].to_cstr(); + + const std::string& request_field_value = conf.value_mapped_workflows(fi).request_field_value(); if (_request_to_workflow_map.insert(request_field_value, workflows) == NULL) { LOG(FATAL) << "insert [" << request_field_value << "," @@ -100,9 +94,9 @@ int InferService::init(const comcfg::ConfigUnit& conf) { << "], request_field_value[" << request_field_value << "]."; } } else { + uint32_t flow_size = conf.workflows_size(); for (uint32_t fi = 0; fi < flow_size; fi++) { - const std::string& workflow_name = - conf["workflow"][fi].to_cstr(); + const std::string& workflow_name = conf.workflows(fi); Workflow* workflow = WorkflowManager::instance().item(workflow_name); if (workflow == NULL) { diff --git a/predictor/framework/service.h b/predictor/framework/service.h index c685c8d9256219c0d77aaec8ce901517021dd740..b187c5c0dd753bd93c70fe3663def123061724f8 100644 --- a/predictor/framework/service.h +++ b/predictor/framework/service.h @@ -26,7 +26,7 @@ public: _request_to_workflow_map.clear(); } - int init(const comcfg::ConfigUnit& conf); + int init(const configure::InferService& conf); int deinit() { return 0; } diff --git a/predictor/framework/workflow.cpp b/predictor/framework/workflow.cpp index 4c343179a4ae39e3f1ede27ffabaa483acde6b59..4a6c06539fac2bc24c58b8473522dbfe95ab2f2d 100644 --- a/predictor/framework/workflow.cpp +++ b/predictor/framework/workflow.cpp @@ -6,20 +6,11 @@ namespace baidu { namespace paddle_serving { namespace predictor { -int Workflow::init(const comcfg::ConfigUnit& conf) { - const std::string& name = conf["name"].to_cstr(); - const std::string& path = conf["path"].to_cstr(); - const std::string& file = conf["file"].to_cstr(); - comcfg::Configure wf_conf; - if (wf_conf.load(path.c_str(), file.c_str()) != 0) { - LOG(ERROR) - << "Failed load workflow, conf:" - << path << "/" << file << "!"; - return -1; - } - _type = wf_conf["workflow_type"].to_cstr(); +int Workflow::init(const configure::Workflow& conf) { + const std::string& name = conf.name(); + _type = conf.workflow_type(); _name = name; - if (_dag.init(wf_conf, name) != 0) { + if (_dag.init(conf, name) != 0) { LOG(ERROR) << "Failed initialize dag: " << _name; return -1; } diff --git a/predictor/framework/workflow.h b/predictor/framework/workflow.h index 9c4f77436be6a2830cb6d5255f0b1876cbce255e..552e004e61c099e9481fdfa019dc7de9d03c35af 100644 --- a/predictor/framework/workflow.h +++ b/predictor/framework/workflow.h @@ -23,7 +23,7 @@ public: // Each workflow object corresponds to an independent // configure file, so you can share the object between // different apps. - int init(const comcfg::ConfigUnit& conf); + int init(const configure::Workflow& conf); DagView* fetch_dag_view(const std::string& service_name); diff --git a/predictor/op/op.h b/predictor/op/op.h index 73928f6862bc755e6cc6770b0fd468c2c1fcc1bd..0abbff18dfe734d9d2a986c75d1a9d9f759497f1 100644 --- a/predictor/op/op.h +++ b/predictor/op/op.h @@ -148,7 +148,7 @@ public: virtual int inference() = 0; // ------------------ Conf Interface ------------------- - virtual void* create_config(const comcfg::ConfigUnit& conf) { return NULL; } + virtual void* create_config(const configure::DAGNode& conf) { return NULL; } virtual void delete_config(void* conf) { } diff --git a/predictor/src/pdserving.cpp b/predictor/src/pdserving.cpp index eda0c0ff686def6cad5c044865a87101c3da0a6d..dad683d81452535ee47f10beb046d28c0fefb332 100644 --- a/predictor/src/pdserving.cpp +++ b/predictor/src/pdserving.cpp @@ -29,6 +29,9 @@ using baidu::paddle_serving::predictor::FLAGS_resource_file; using baidu::paddle_serving::predictor::FLAGS_reload_interval_s; using baidu::paddle_serving::predictor::FLAGS_port; +using baidu::paddle_serving::configure::InferServiceConf; +using baidu::paddle_serving::configure::read_proto_conf; + void print_revision(std::ostream& os, void*) { #if defined(PDSERVING_VERSION) os << PDSERVING_VERSION; @@ -52,15 +55,14 @@ void pthread_worker_start_fn() { } static void g_change_server_port() { - comcfg::Configure conf; - if (conf.load(FLAGS_inferservice_path.c_str(), FLAGS_inferservice_file.c_str()) != 0) { + InferServiceConf conf; + if (read_proto_conf(FLAGS_inferservice_path.c_str(), FLAGS_inferservice_file.c_str(), &conf) != 0) { LOG(WARNING) << "failed to load configure[" << FLAGS_inferservice_path << "," << FLAGS_inferservice_file << "]."; return; } - uint32_t port = 0; - int err = conf["port"].get_uint32(&port, 0); - if (err == 0) { + uint32_t port = conf.port(); + if (port != 0) { FLAGS_port = port; LOG(INFO) << "use configure[" << FLAGS_inferservice_path << "/" << FLAGS_inferservice_file << "] port[" << port << "] instead of flags"; diff --git a/proto_configure/CMakeLists.txt b/proto_configure/CMakeLists.txt index 5cd32c4edd5bddae1751512a6ac62a94a5717a02..89cf870d920af2c86ceea762ce50947dc7d56fa8 100644 --- a/proto_configure/CMakeLists.txt +++ b/proto_configure/CMakeLists.txt @@ -1,5 +1,6 @@ LIST(APPEND protofiles - ${CMAKE_CURRENT_LIST_DIR}/proto/configure.proto + ${CMAKE_CURRENT_LIST_DIR}/proto/server_configure.proto + ${CMAKE_CURRENT_LIST_DIR}/proto/sdk_configure.proto ) PROTOBUF_GENERATE_CPP(configure_proto_srcs configure_proto_hdrs ${protofiles}) diff --git a/proto_configure/proto/configure.proto b/proto_configure/proto/configure.proto deleted file mode 100644 index 7ccf8f5464d940314b737fd2b91996e79a20b85d..0000000000000000000000000000000000000000 --- a/proto_configure/proto/configure.proto +++ /dev/null @@ -1,66 +0,0 @@ -syntax="proto2"; -package baidu.paddle_serving.configure; - -message EngineDesc { - required string type = 1; - required string reloadable_meta = 2; - required string reloadable_type = 3; - required string model_data_path = 4; - required uint32 runtime_thread_num = 5; - required uint32 batch_infer_size = 6; - required uint32 enable_batch_align = 7; - optional string version_file = 8; - optional string version_type = 9; -}; - -message VersionedEngine { - required string name = 1; - repeated EngineDesc versions = 2; - optional EngineDesc default_version = 3; -}; - -// model_toolkit conf -message ModelToolkitConf { - repeated VersionedEngine engines = 1; -}; - -// reource conf -message ResourceConf { - required string model_toolkit_path = 1; - required string model_toolkit_file = 2; -}; - -// DAG node depency info -message DAGNodeDependency { - required string name = 1; - required string mode = 2; -}; - -// DAG Node -message DAGNode { - required string name = 1; - required string type = 2; - repeated DAGNodeDependency dependencies = 3; -}; - -// workflow entry -message Workflow { - required string name = 1; - required string workflow_type = 2; - repeated DAGNode nodes = 3; -}; - -// Workflow conf -message WorkflowConf { - repeated Workflow workflow = 1; -} - -message InferService { - required string name = 1; - repeated string workflow = 2; -}; - -// InferService conf -message InferServiceConf { - repeated InferService service = 1; -}; diff --git a/proto_configure/proto/sdk_configure.proto b/proto_configure/proto/sdk_configure.proto new file mode 100644 index 0000000000000000000000000000000000000000..1c7b9a38e97524ba6f390a04a3d6e2660e7dec2c --- /dev/null +++ b/proto_configure/proto/sdk_configure.proto @@ -0,0 +1,59 @@ +syntax="proto2"; +package baidu.paddle_serving.configure; + +message ConnectionConf { + required uint32 connect_timeout_ms = 1; + required uint32 rpc_timeout_ms = 2; + required uint32 connect_retry_count = 3; + required uint32 max_connection_per_host = 4; + required uint32 hedge_request_timeout_ms = 5; + required uint32 hedge_fetch_retry_count = 6; + required string connection_type = 7; +}; + +message NamingConf { + optional string cluster_filter_strategy = 1; + optional string load_balance_strategy = 2; + optional string cluster = 3; +}; + +message RpcParameter { + // 0-NONE, 1-SNAPPY, 2-GZIP, 3-ZLIB, 4-LZ4 + required uint32 compress_type = 1; + required uint32 package_size = 2; + required string protocol = 3; + required uint32 max_channel_per_request = 4; +}; + +message SplitConf{ + optional string split_tag_name = 1; + optional string tag_candidates = 2; +}; + +message VariantConf { + required string tag = 1; + optional ConnectionConf connection_conf = 2; + optional NamingConf naming_conf = 3; + optional RpcParameter rpc_parameter = 4; + optional SplitConf split_conf = 5; + optional string variant_router = 6; +}; + +message WeightedRandomRenderConf { + required string variant_weight_list = 1; +}; + +message Predictor { + required string name = 1; + required string service_name = 2; + required string endpoint_router = 3; + required WeightedRandomRenderConf weighted_random_render_conf = 4; + repeated VariantConf variants = 5; +}; + +// SDK conf +message SDKConf { + required VariantConf default_variant_conf = 1; + repeated Predictor predictors = 2; +}; + diff --git a/proto_configure/proto/server_configure.proto b/proto_configure/proto/server_configure.proto new file mode 100644 index 0000000000000000000000000000000000000000..32e7384506d00d7d17360e3096f429f8b4443632 --- /dev/null +++ b/proto_configure/proto/server_configure.proto @@ -0,0 +1,89 @@ +syntax="proto2"; +package baidu.paddle_serving.configure; + +message EngineDesc { + required string name = 1; + required string type = 2; + required string reloadable_meta = 3; + required string reloadable_type = 4; + required string model_data_path = 5; + required uint32 runtime_thread_num = 6; + required uint32 batch_infer_size = 7; + required uint32 enable_batch_align = 8; + optional string version_file = 9; + optional string version_type = 10; +}; + +// model_toolkit conf +message ModelToolkitConf { + repeated EngineDesc engines = 1; +}; + +// reource conf +message ResourceConf { + required string model_toolkit_path = 1; + required string model_toolkit_file = 2; +}; + +// DAG node depency info +message DAGNodeDependency { + required string name = 1; + required string mode = 2; +}; + +// DAG Node +message DAGNode { + required string name = 1; + required string type = 2; + repeated DAGNodeDependency dependencies = 3; +}; + +// workflow entry +message Workflow { + required string name = 1; + required string workflow_type = 2; + repeated DAGNode nodes = 3; +}; + +// Workflow conf +message WorkflowConf { + repeated Workflow workflows = 1; +} + + +// request_field_key: specifies use which request field as mapping key (see +// request_field_key in InferService below) +// +// If the value of the user request field specified by `request_field_key` +// matches the value of `request_field_value` in one of the +// ValueMappedWorkflows, the request will be directed to the workflow specified +// in the `workflow` field of that ValueMappedWorkflow +// +message ValueMappedWorkflow { + required string request_field_value = 1; + required string workflow = 2; +}; + +message InferService { + required string name = 1; + optional string merger = 2; + + optional bool enable_map_request_to_workflow = 3 [default = false]; + + // If enable_map_request_to_workfow = true + // + // Each request will be mapped to a workflow according to the value in + // in user request field specified by `request_field_key` (see the + // comments for ValueMappedWorkflow above) + optional string request_field_key = 4; + repeated ValueMappedWorkflow value_mapped_workflows = 5; + + // If enable_map_request_to_workflow = false + repeated string workflows = 6; +}; + +// InferService conf +message InferServiceConf { + optional uint32 port = 1; + repeated InferService services = 2; +}; diff --git a/proto_configure/src/configure_parser.cpp b/proto_configure/src/configure_parser.cpp index ba41caf3d298b2efef54757a04e155685bcb502d..d527b8ee581f43f469de470bf10e4143677e4c82 100644 --- a/proto_configure/src/configure_parser.cpp +++ b/proto_configure/src/configure_parser.cpp @@ -15,7 +15,7 @@ int read_proto_conf(const std::string &conf_path, const std::string &conf_file, google::protobuf::Message *conf) { - std::string file_str = conf_path + conf_file; + std::string file_str = conf_path + "/" + conf_file; int fd = open(file_str.c_str(), O_RDONLY); if (fd == -1) { LOG(WARNING) << "File not found: " << file_str.c_str(); @@ -39,7 +39,7 @@ int write_proto_conf(google::protobuf::Message *message, std::string binary_str; google::protobuf::TextFormat::PrintToString(*message, &binary_str); - std::string file_str = output_path + output_file; + std::string file_str = output_path + "/" + output_file; std::ofstream fout_bin((file_str.c_str())); if (!fout_bin) { LOG(WARNING) << "Open file error: " << file_str.c_str(); diff --git a/proto_configure/tests/test_configure.cpp b/proto_configure/tests/test_configure.cpp index 8f1cac41b9ebafc286c4b12c2097e132a1b4bfa2..d145a0f7cb8953437c85762d04b88c0089b66b0a 100644 --- a/proto_configure/tests/test_configure.cpp +++ b/proto_configure/tests/test_configure.cpp @@ -2,11 +2,11 @@ #include #include #include -#include "configure.pb.h" +#include "server_configure.pb.h" +#include "sdk_configure.pb.h" #include "configure_parser.h" using baidu::paddle_serving::configure::EngineDesc; -using baidu::paddle_serving::configure::VersionedEngine; using baidu::paddle_serving::configure::ModelToolkitConf; using baidu::paddle_serving::configure::ResourceConf; @@ -19,11 +19,21 @@ using baidu::paddle_serving::configure::WorkflowConf; using baidu::paddle_serving::configure::InferService; using baidu::paddle_serving::configure::InferServiceConf; +using baidu::paddle_serving::configure::ConnectionConf; +using baidu::paddle_serving::configure::WeightedRandomRenderConf; +using baidu::paddle_serving::configure::NamingConf; +using baidu::paddle_serving::configure::RpcParameter; +using baidu::paddle_serving::configure::Predictor; +using baidu::paddle_serving::configure::VariantConf; + +using baidu::paddle_serving::configure::SDKConf; + const std::string output_dir = "./conf/"; const std::string model_toolkit_conf_file = "model_toolkit.prototxt"; const std::string resource_conf_file = "resource.prototxt"; const std::string workflow_conf_file = "workflow.prototxt"; const std::string service_conf_file = "service.prototxt"; +const std::string sdk_conf_file = "predictors.protxt"; int test_write_conf() { @@ -31,38 +41,15 @@ int test_write_conf() ModelToolkitConf model_toolkit_conf; // This engine has a default version - VersionedEngine *engine = model_toolkit_conf.add_engines(); + EngineDesc *engine = model_toolkit_conf.add_engines(); engine->set_name("image_classification_resnet"); - EngineDesc *engine_desc = engine->mutable_default_version(); - engine_desc->set_type("FLUID_CPU_NATIVE_V2"); - engine_desc->set_reloadable_meta("./data/model/paddle/fluid_time_file"); - engine_desc->set_reloadable_type("timestamp_ne"); - engine_desc->set_model_data_path("./data/model/paddle/fluid/SE_ResNeXt50_32x4d"); - engine_desc->set_runtime_thread_num(0); - engine_desc->set_batch_infer_size(0); - engine_desc->set_enable_batch_align(0); - - // This engine has two versioned branches - engine = model_toolkit_conf.add_engines(); - engine->set_name("image_classification_resnet_versioned"); - // Version 1 - engine_desc = engine->add_versions(); - engine_desc->set_type("FLUID_CPU_NATIVE_DIR"); - engine_desc->set_reloadable_meta("./data/model/paddle/fluid_time_file"); - engine_desc->set_reloadable_type("timestamp_ne"); - engine_desc->set_model_data_path("./data/model/paddle/fluid/SE_ResNeXt50_32x4d"); - engine_desc->set_runtime_thread_num(0); - engine_desc->set_batch_infer_size(0); - engine_desc->set_enable_batch_align(0); - // Version 2 - engine_desc = engine->add_versions(); - engine_desc->set_type("FLUID_CPU_NATIVE_DIR"); - engine_desc->set_reloadable_meta("./data/model/paddle/fluid_time_file_2"); - engine_desc->set_reloadable_type("timestamp_ne_2"); - engine_desc->set_model_data_path("./data/model/paddle/fluid/SE_ResNeXt50_32x4d_2"); - engine_desc->set_runtime_thread_num(0); - engine_desc->set_batch_infer_size(0); - engine_desc->set_enable_batch_align(0); + engine->set_type("FLUID_CPU_NATIVE_DIR"); + engine->set_reloadable_meta("./data/model/paddle/fluid_time_file"); + engine->set_reloadable_type("timestamp_ne"); + engine->set_model_data_path("./data/model/paddle/fluid/SE_ResNeXt50_32x4d"); + engine->set_runtime_thread_num(0); + engine->set_batch_infer_size(0); + engine->set_enable_batch_align(0); int ret = baidu::paddle_serving::configure::write_proto_conf(&model_toolkit_conf, output_dir, model_toolkit_conf_file); if (ret != 0) { @@ -72,7 +59,7 @@ int test_write_conf() // resource conf ResourceConf resource_conf; resource_conf.set_model_toolkit_path(output_dir); - resource_conf.set_model_toolkit_file("resource.prototxt"); + resource_conf.set_model_toolkit_file("model_toolkit.prototxt"); ret = baidu::paddle_serving::configure::write_proto_conf(&resource_conf, output_dir, resource_conf_file); if (ret != 0) { return ret; @@ -80,7 +67,7 @@ int test_write_conf() // workflow entries conf WorkflowConf workflow_conf; - Workflow *workflow = workflow_conf.add_workflow(); + Workflow *workflow = workflow_conf.add_workflows(); workflow->set_name("workflow1"); workflow->set_workflow_type("Sequence"); @@ -102,7 +89,7 @@ int test_write_conf() node_dependency->set_name("image_classify_op"); node_dependency->set_mode("RO"); - workflow = workflow_conf.add_workflow(); + workflow = workflow_conf.add_workflows(); workflow->set_name("workflow2"); workflow->set_workflow_type("Sequence"); @@ -116,19 +103,62 @@ int test_write_conf() } InferServiceConf infer_service_conf; - InferService *infer_service = infer_service_conf.add_service(); + infer_service_conf.set_port(0); + InferService *infer_service = infer_service_conf.add_services(); infer_service->set_name("ImageClassifyService"); - infer_service->add_workflow("workflow1"); - infer_service->add_workflow("workflow2"); + infer_service->add_workflows("workflow1"); + infer_service->add_workflows("workflow2"); - infer_service = infer_service_conf.add_service(); + infer_service = infer_service_conf.add_services(); infer_service->set_name("BuiltinDenseFormatService"); - infer_service->add_workflow("workflow2"); + infer_service->add_workflows("workflow2"); ret = baidu::paddle_serving::configure::write_proto_conf(&infer_service_conf, output_dir, service_conf_file); if (ret != 0) { return ret; } + + SDKConf sdk_conf; + VariantConf *default_variant_conf = sdk_conf.mutable_default_variant_conf(); + default_variant_conf->set_tag("default"); + + ConnectionConf *connection_conf = default_variant_conf->mutable_connection_conf(); + connection_conf->set_connect_timeout_ms(2000); + connection_conf->set_rpc_timeout_ms(20000); + connection_conf->set_connect_retry_count(2); + connection_conf->set_max_connection_per_host(100); + connection_conf->set_hedge_request_timeout_ms(-1); + connection_conf->set_hedge_fetch_retry_count(2); + connection_conf->set_connection_type("pooled"); + + NamingConf *naming_conf = default_variant_conf->mutable_naming_conf(); + naming_conf->set_cluster_filter_strategy("Default"); + naming_conf->set_load_balance_strategy("la"); + + RpcParameter *rpc_parameter = default_variant_conf->mutable_rpc_parameter(); + rpc_parameter->set_compress_type(0); + rpc_parameter->set_package_size(20); + rpc_parameter->set_protocol("baidu_std"); + rpc_parameter->set_max_channel_per_request(3); + + Predictor *predictor = sdk_conf.add_predictors(); + predictor->set_name("ximage"); + predictor->set_service_name("baidu.paddle_serving.predictor.image_classification.ImageClassifyService"); + predictor->set_endpoint_router("WeightedRandomRender"); + + WeightedRandomRenderConf *weighted_random_render_conf = predictor->mutable_weighted_random_render_conf(); + weighted_random_render_conf->set_variant_weight_list("50"); + + VariantConf *variant_conf = predictor->add_variants(); + variant_conf->set_tag("var1"); + naming_conf = variant_conf->mutable_naming_conf(); + naming_conf->set_cluster("list://127.0.0.1:8010"); + + ret = baidu::paddle_serving::configure::write_proto_conf(&sdk_conf, output_dir, sdk_conf_file); + if (ret != 0) { + return ret; + } + return 0; } @@ -164,6 +194,13 @@ int test_read_conf() return -1; } + SDKConf sdk_conf; + ret = baidu::paddle_serving::configure::read_proto_conf(output_dir, sdk_conf_file, &sdk_conf); + if (ret != 0) { + std::cout << "Read conf fail: " << sdk_conf_file << std::endl; + return -1; + } + return 0; } diff --git a/sdk-cpp/CMakeLists.txt b/sdk-cpp/CMakeLists.txt index 2ac0ec8d5dcf1040a642d38e8b1f857879163ec4..c368059e5ccb1b140eec5e834364b56a44a0bfbf 100644 --- a/sdk-cpp/CMakeLists.txt +++ b/sdk-cpp/CMakeLists.txt @@ -1,21 +1,25 @@ include(src/CMakeLists.txt) include(proto/CMakeLists.txt) add_library(sdk-cpp ${sdk_cpp_srcs}) -add_dependencies(sdk-cpp configure) +add_dependencies(sdk-cpp configure proto_configure) target_include_directories(sdk-cpp PUBLIC ${CMAKE_CURRENT_LIST_DIR}/include ${CMKAE_CURRENT_BINARY_DIR}/ + ${CMAKE_CURRENT_BINARY_DIR}/../proto_configure ${CMAKE_CURRENT_LIST_DIR}/../configure + ${CMAKE_CURRENT_LIST_DIR}/../proto_configure/include ${CMAKE_CURRENT_LIST_DIR}/../ullib/include ${CMAKE_CURRENT_BINARY_DIR}/../bsl/include ) -target_link_libraries(sdk-cpp brpc configure protobuf leveldb) +target_link_libraries(sdk-cpp brpc configure proto_configure protobuf leveldb) add_executable(ximage ${CMAKE_CURRENT_LIST_DIR}/demo/ximage.cpp) target_include_directories(ximage PUBLIC ${CMAKE_CURRENT_LIST_DIR}/include ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_BINARY_DIR}/../proto_configure ${CMAKE_CURRENT_LIST_DIR}/../configure + ${CMAKE_CURRENT_LIST_DIR}/../proto_configure/include ${CMAKE_CURRENT_LIST_DIR}/../ullib/include ${CMAKE_CURRENT_BINARY_DIR}/../bsl/include) target_link_libraries(ximage sdk-cpp -lpthread -lcrypto -lm -lrt -lssl -ldl @@ -25,7 +29,9 @@ add_executable(mapcnn_dense ${CMAKE_CURRENT_LIST_DIR}/demo/mapcnn_dense.cpp) target_include_directories(mapcnn_dense PUBLIC ${CMAKE_CURRENT_LIST_DIR}/include ${CMAKE_CURRENT_BINARY_DIR}/ + ${CMAKE_CURRENT_BINARY_DIR}/../proto_configure ${CMAKE_CURRENT_LIST_DIR}/../configure + ${CMAKE_CURRENT_LIST_DIR}/../proto_configure/include ${CMAKE_CURRENT_LIST_DIR}/../ullib/include ${CMAKE_CURRENT_BINARY_DIR}/../bsl/include) target_link_libraries(mapcnn_dense sdk-cpp -lpthread -lcrypto -lm -lrt -lssl @@ -35,7 +41,9 @@ add_executable(mapcnn_sparse ${CMAKE_CURRENT_LIST_DIR}/demo/mapcnn_sparse.cpp) target_include_directories(mapcnn_sparse PUBLIC ${CMAKE_CURRENT_LIST_DIR}/include ${CMAKE_CURRENT_BINARY_DIR}/ + ${CMAKE_CURRENT_BINARY_DIR}/../proto_configure ${CMAKE_CURRENT_LIST_DIR}/../configure + ${CMAKE_CURRENT_LIST_DIR}/../proto_configure/include ${CMAKE_CURRENT_LIST_DIR}/../ullib/include ${CMAKE_CURRENT_BINARY_DIR}/../bsl/include) target_link_libraries(mapcnn_sparse sdk-cpp -lpthread -lcrypto -lm -lrt -lssl diff --git a/sdk-cpp/demo/ximage.cpp b/sdk-cpp/demo/ximage.cpp index 12b7bc093440cb5416b864e9e03e5320cb6cb838..74aaee25bb0b21c80bb750c709e775798c96ffd4 100644 --- a/sdk-cpp/demo/ximage.cpp +++ b/sdk-cpp/demo/ximage.cpp @@ -118,7 +118,7 @@ void print_res( int main(int argc, char** argv) { PredictorApi api; - if (api.create("./conf", "predictors.conf") != 0) { + if (api.create("./conf", "predictors.prototxt") != 0) { LOG(FATAL) << "Failed create predictors api!"; return -1; } diff --git a/sdk-cpp/include/abtest.h b/sdk-cpp/include/abtest.h index 0123b7f8e9a5bce3ed5f28ff7f398f90dc82b123..3e00097054bb81cf6a33f4e71d9e789226a4943d 100644 --- a/sdk-cpp/include/abtest.h +++ b/sdk-cpp/include/abtest.h @@ -18,6 +18,7 @@ #include "stub.h" #include "common.h" #include "factory.h" +#include namespace baidu { namespace paddle_serving { @@ -35,7 +36,7 @@ public: virtual ~EndpointRouterBase() {} virtual int initialize( - const comcfg::ConfigUnit& conf) = 0; + const google::protobuf::Message& conf) = 0; virtual Variant* route(const VariantList&) = 0; @@ -56,7 +57,7 @@ public: ~WeightedRandomRender() {} int initialize( - const comcfg::ConfigUnit& conf); + const google::protobuf::Message& conf); Variant* route(const VariantList&); diff --git a/sdk-cpp/include/common.h b/sdk-cpp/include/common.h index c91e2e51adf5d6bd6131eab76f8a0e541723c029..dfcd9108be849da1f30da3ef83df4b4a78a490cc 100644 --- a/sdk-cpp/include/common.h +++ b/sdk-cpp/include/common.h @@ -42,6 +42,8 @@ #include #include "Configure.h" +#include "sdk_configure.pb.h" +#include "configure_parser.h" #include "utils.h" diff --git a/sdk-cpp/include/config_manager.h b/sdk-cpp/include/config_manager.h index 7d08c09206effd50f39524bcd053cd2d95448018..5696092b6068b6ccb7242f11e8921adf1160b6a4 100644 --- a/sdk-cpp/include/config_manager.h +++ b/sdk-cpp/include/config_manager.h @@ -62,17 +62,17 @@ public: private: int init_one_variant( - const comcfg::ConfigUnit& conf, + const configure::VariantConf& conf, VariantInfo& var); int init_one_endpoint( - const comcfg::ConfigUnit& conf, + const configure::Predictor& conf, EndpointInfo& ep, const VariantInfo& default_var); int merge_variant( const VariantInfo& default_var, - const comcfg::ConfigUnit& conf, + const configure::VariantConf& conf, VariantInfo& merged_var); int parse_tag_values( diff --git a/sdk-cpp/include/endpoint_config.h b/sdk-cpp/include/endpoint_config.h index 4c272aac6c2bac81e633f9128cce6ffd11f2b20b..cca2f6f1b912252fd145e8c123949ac25442f69c 100644 --- a/sdk-cpp/include/endpoint_config.h +++ b/sdk-cpp/include/endpoint_config.h @@ -27,12 +27,7 @@ namespace sdk_cpp { #define PARSE_CONF_ITEM(conf, item, name, fail) \ do { \ try { \ - item.set(conf[name]); \ - } catch (comcfg::NoSuchKeyException& e) { \ - LOG(INFO) << "Not found key in configue: " << name;\ - } catch (comcfg::ConfigException& e) { \ - LOG(FATAL) << "Error config, key: " << name; \ - return fail; \ + item.set(conf.name()); \ } catch (...) { \ LOG(FATAL) << "Unkown error accurs when load config";\ return fail; \ @@ -60,55 +55,10 @@ template struct ConfigItem { T value; bool init; ConfigItem() : init(false) {} - void set(const comcfg::ConfigUnit& unit) { - set_impl(type_traits::tag, unit); + void set(const T& unit) { + value = unit; init = true; } - - void set_impl(type_traits&, - const comcfg::ConfigUnit& unit) { - value = unit.to_int16(); - } - - void set_impl(type_traits&, - const comcfg::ConfigUnit& unit) { - value = unit.to_int32(); - } - - void set_impl(type_traits&, - const comcfg::ConfigUnit& unit) { - value = unit.to_int64(); - } - - void set_impl(type_traits&, - const comcfg::ConfigUnit& unit) { - value = unit.to_uint16(); - } - - void set_impl(type_traits&, - const comcfg::ConfigUnit& unit) { - value = unit.to_uint32(); - } - - void set_impl(type_traits&, - const comcfg::ConfigUnit& unit) { - value = unit.to_uint64(); - } - - void set_impl(type_traits&, - const comcfg::ConfigUnit& unit) { - value = unit.to_float(); - } - - void set_impl(type_traits&, - const comcfg::ConfigUnit& unit) { - value = unit.to_double(); - } - - void set_impl(type_traits&, - const comcfg::ConfigUnit& unit) { - value = unit.to_cstr(); - } }; struct Connection { diff --git a/sdk-cpp/src/abtest.cpp b/sdk-cpp/src/abtest.cpp index c9408f956e83b00fb6c133fb94cbaecaf64f022a..cd7ea1b2505970c054c2d604026560983e262e35 100644 --- a/sdk-cpp/src/abtest.cpp +++ b/sdk-cpp/src/abtest.cpp @@ -20,11 +20,14 @@ namespace sdk_cpp { int WeightedRandomRender::initialize( - const comcfg::ConfigUnit& conf) { + const google::protobuf::Message& conf) { srand((unsigned)time(NULL)); try { + const configure::WeightedRandomRenderConf &weighted_random_render_conf = + dynamic_cast(conf); + std::string weights - = conf["VariantWeightList"].to_cstr(); + = weighted_random_render_conf.variant_weight_list(); std::vector splits; if (str_split(weights, WEIGHT_SEPERATOR, &splits) != 0) { @@ -57,7 +60,7 @@ int WeightedRandomRender::initialize( LOG(INFO) << "Succ read weights list: " << weights << ", count: " << _variant_weight_list.size() << ", normalized: " << _normalized_sum; - } catch (comcfg::ConfigException& e) { + } catch (std::bad_cast& e) { LOG(FATAL) << "Failed init WeightedRandomRender" << "from configure, err:" << e.what(); return -1; diff --git a/sdk-cpp/src/config_manager.cpp b/sdk-cpp/src/config_manager.cpp index 851cb446b2683bf8535f45de51258f720a8891f6..e24bcaad32c9dd599e0fe44b71a2c1706c1f69c8 100644 --- a/sdk-cpp/src/config_manager.cpp +++ b/sdk-cpp/src/config_manager.cpp @@ -19,6 +19,8 @@ namespace baidu { namespace paddle_serving { namespace sdk_cpp { +using configure::SDKConf; + int EndpointConfigManager::create(const char* path, const char* file) { _endpoint_config_path = path; _endpoint_config_file = file; @@ -33,10 +35,11 @@ int EndpointConfigManager::create(const char* path, const char* file) { int EndpointConfigManager::load() { try { - comcfg::Configure conf; - if (conf.load( + SDKConf sdk_conf; + if (configure::read_proto_conf( _endpoint_config_path.c_str(), - _endpoint_config_file.c_str()) != 0) { + _endpoint_config_file.c_str(), + &sdk_conf) != 0) { LOG(FATAL) << "Failed initialize endpoint list" << ", config: " << _endpoint_config_path @@ -45,16 +48,16 @@ int EndpointConfigManager::load() { } VariantInfo default_var; - if (init_one_variant(conf["DefaultVariantInfo"], + if (init_one_variant(sdk_conf.default_variant_conf(), default_var) != 0) { LOG(FATAL) << "Failed read default var conf"; return -1; } - uint32_t ep_size = conf["Predictor"].size(); + uint32_t ep_size = sdk_conf.predictors_size(); for (uint32_t ei = 0; ei < ep_size; ++ei) { EndpointInfo ep; - if (init_one_endpoint(conf["Predictor"][ei], ep, + if (init_one_endpoint(sdk_conf.predictors(ei), ep, default_var) != 0) { LOG(FATAL) << "Failed read endpoint info at: " << ei; @@ -88,36 +91,41 @@ int EndpointConfigManager::load() { } int EndpointConfigManager::init_one_endpoint( - const comcfg::ConfigUnit& conf, EndpointInfo& ep, + const configure::Predictor& conf, EndpointInfo& ep, const VariantInfo& dft_var) { try { // name - ep.endpoint_name = conf["name"].to_cstr(); + ep.endpoint_name = conf.name(); // stub - ep.stub_service = conf["service_name"].to_cstr(); + ep.stub_service = conf.service_name(); // abtest ConfigItem ep_router; - PARSE_CONF_ITEM(conf, ep_router, "endpoint_router", -1); + PARSE_CONF_ITEM(conf, ep_router, endpoint_router, -1); if (ep_router.init) { - std::string endpoint_router_info - = conf["endpoint_router"].to_cstr(); + if (ep_router.value != "WeightedRandomRenderConf") { + LOG(FATAL) << "endpointer_router unrecognized " << ep_router.value; + return -1; + } + EndpointRouterBase* router = EndpointRouterFactory::instance().generate_object( ep_router.value); - if (!router || router->initialize( - conf[endpoint_router_info.c_str()]) != 0) { + + const configure::WeightedRandomRenderConf &router_conf = + conf.weighted_random_render_conf(); + if (!router || router->initialize(router_conf) != 0) { LOG(FATAL) << "Failed fetch valid ab test strategy" - << ", name:" << endpoint_router_info; + << ", name:" << ep_router.value; return -1; } ep.ab_test = router; } // varlist - uint32_t var_size = conf["VariantInfo"].size(); + uint32_t var_size = conf.variants_size(); for (uint32_t vi = 0; vi < var_size; ++vi) { VariantInfo var; - if (merge_variant(dft_var, conf["VariantInfo"][vi], + if (merge_variant(dft_var, conf.variants(vi), var) != 0) { LOG(FATAL) << "Failed merge variant info at: " << vi; @@ -146,54 +154,54 @@ int EndpointConfigManager::init_one_endpoint( } int EndpointConfigManager::init_one_variant( - const comcfg::ConfigUnit& conf, VariantInfo& var) { + const configure::VariantConf& conf, VariantInfo& var) { try { // Connect - const comcfg::ConfigUnit& conn = conf["Connection"]; + const configure::ConnectionConf& conn = conf.connection_conf(); PARSE_CONF_ITEM(conn, var.connection.tmo_conn, - "ConnectTimeoutMilliSec", -1); + connect_timeout_ms, -1); PARSE_CONF_ITEM(conn, var.connection.tmo_rpc, - "RpcTimeoutMilliSec", -1); + rpc_timeout_ms, -1); PARSE_CONF_ITEM(conn, var.connection.tmo_hedge, - "HedgeRequestTimeoutMilliSec", -1); + hedge_request_timeout_ms, -1); PARSE_CONF_ITEM(conn, var.connection.cnt_retry_conn, - "ConnectRetryCount", -1); + connect_retry_count, -1); PARSE_CONF_ITEM(conn, var.connection.cnt_retry_hedge, - "HedgeFetchRetryCount", -1); + hedge_fetch_retry_count, -1); PARSE_CONF_ITEM(conn, var.connection.cnt_maxconn_per_host, - "MaxConnectionPerHost", -1); + max_connection_per_host, -1); PARSE_CONF_ITEM(conn, var.connection.type_conn, - "ConnectionType", -1); + connection_type, -1); // Naming - const comcfg::ConfigUnit& name = conf["NamingInfo"]; + const configure::NamingConf& name = conf.naming_conf(); PARSE_CONF_ITEM(name, var.naminginfo.cluster_naming, - "Cluster", -1); + cluster, -1); PARSE_CONF_ITEM(name, var.naminginfo.load_balancer, - "LoadBalanceStrategy", -1); + load_balance_strategy, -1); PARSE_CONF_ITEM(name, var.naminginfo.cluster_filter, - "ClusterFilterStrategy", -1); + cluster_filter_strategy, -1); // Rpc - const comcfg::ConfigUnit& params = conf["RpcParameter"]; + const configure::RpcParameter& params = conf.rpc_parameter(); PARSE_CONF_ITEM(params, var.parameters.protocol, - "Protocol", -1); + protocol, -1); PARSE_CONF_ITEM(params, var.parameters.compress_type, - "CompressType", -1); + compress_type, -1); PARSE_CONF_ITEM(params, var.parameters.package_size, - "PackageSize", -1); + package_size, -1); PARSE_CONF_ITEM(params, var.parameters.max_channel, - "MaxChannelPerRequest", -1); + max_channel_per_request, -1); // Split - const comcfg::ConfigUnit& splits = conf["SplitInfo"]; + const configure::SplitConf& splits = conf.split_conf(); PARSE_CONF_ITEM(splits, var.splitinfo.split_tag, - "split_tag_name", -1); + split_tag_name, -1); PARSE_CONF_ITEM(splits, var.splitinfo.tag_cands_str, - "tag_candidates", -1); + tag_candidates, -1); if (parse_tag_values(var.splitinfo) != 0) { LOG(FATAL) << "Failed parse tag_values:" << var.splitinfo.tag_cands_str.value; @@ -202,11 +210,11 @@ int EndpointConfigManager::init_one_variant( // tag PARSE_CONF_ITEM(conf, var.parameters.route_tag, - "Tag", -1); + tag, -1); // router ConfigItem var_router; - PARSE_CONF_ITEM(conf, var_router, "variant_router", -1); + PARSE_CONF_ITEM(conf, var_router, variant_router, -1); if (var_router.init) { VariantRouterBase* router = VariantRouterFactory::instance().generate_object( @@ -230,7 +238,7 @@ int EndpointConfigManager::init_one_variant( int EndpointConfigManager::merge_variant( const VariantInfo& default_var, - const comcfg::ConfigUnit& conf, + const configure::VariantConf& conf, VariantInfo& merged_var) { merged_var = default_var;