From 57fc01224058e6f360761050e15eaa98c2253c59 Mon Sep 17 00:00:00 2001 From: guru4elephant Date: Mon, 10 Feb 2020 12:32:31 +0800 Subject: [PATCH] fix predictor initialization bug --- core/general-client/include/general_model.h | 2 + core/general-client/src/general_model.cpp | 10 ++++ .../src/pybind_general_model.cpp | 3 ++ core/sdk-cpp/include/config_manager.h | 4 ++ core/sdk-cpp/include/predictor_sdk.h | 7 +++ core/sdk-cpp/src/config_manager.cpp | 47 +++++++++++++++++++ core/sdk-cpp/src/predictor_sdk.cpp | 43 +++++++++++++++++ python/examples/imdb/benchmark.py | 2 +- python/paddle_serving_client/__init__.py | 9 +--- 9 files changed, 119 insertions(+), 8 deletions(-) diff --git a/core/general-client/include/general_model.h b/core/general-client/include/general_model.h index acb5c347..8f0a8247 100644 --- a/core/general-client/include/general_model.h +++ b/core/general-client/include/general_model.h @@ -50,6 +50,8 @@ class PredictorClient { void set_predictor_conf(const std::string& conf_path, const std::string& conf_file); + int create_predictor_by_desc(const std::string & sdk_desc); + int create_predictor(); int destroy_predictor(); diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp index 9ba5ab9c..9552c958 100644 --- a/core/general-client/src/general_model.cpp +++ b/core/general-client/src/general_model.cpp @@ -88,7 +88,17 @@ int PredictorClient::destroy_predictor() { _api.destroy(); } +int PredictorClient::create_predictor_by_desc(const std::string & sdk_desc) { + if (_api.create(sdk_desc) != 0) { + LOG(ERROR) << "Predictor Creation Failed"; + return -1; + } + _api.thrd_initialize(); +} + int PredictorClient::create_predictor() { + VLOG(2) << "Predictor path: " << _predictor_path + << " predictor file: " << _predictor_conf; if (_api.create(_predictor_path.c_str(), _predictor_conf.c_str()) != 0) { LOG(ERROR) << "Predictor Creation Failed"; return -1; diff --git a/core/general-client/src/pybind_general_model.cpp b/core/general-client/src/pybind_general_model.cpp index d1b01412..898b0635 100644 --- a/core/general-client/src/pybind_general_model.cpp +++ b/core/general-client/src/pybind_general_model.cpp @@ -41,6 +41,9 @@ PYBIND11_MODULE(serving_client, m) { const std::string &conf_file) { self.set_predictor_conf(conf_path, conf_file); }) + .def("create_predictor_by_desc", + [](PredictorClient &self, const std::string & sdk_desc) { + self.create_predictor_by_desc(sdk_desc); }) .def("create_predictor", [](PredictorClient &self) { self.create_predictor(); }) .def("destroy_predictor", diff --git a/core/sdk-cpp/include/config_manager.h b/core/sdk-cpp/include/config_manager.h index 7eb409a5..44134716 100644 --- a/core/sdk-cpp/include/config_manager.h +++ b/core/sdk-cpp/include/config_manager.h @@ -32,6 +32,10 @@ class EndpointConfigManager { EndpointConfigManager() : _last_update_timestamp(0), _current_endpointmap_id(1) {} + int create(const std::string & sdk_desc_str); + + int load(const std::string & sdk_desc_str); + int create(const char* path, const char* file); int load(); diff --git a/core/sdk-cpp/include/predictor_sdk.h b/core/sdk-cpp/include/predictor_sdk.h index 34bf8db7..bcc8ab83 100644 --- a/core/sdk-cpp/include/predictor_sdk.h +++ b/core/sdk-cpp/include/predictor_sdk.h @@ -31,6 +31,8 @@ class PredictorApi { int register_all(); + int create(const std::string & sdk_desc_str); + int create(const char* path, const char* file); int thrd_initialize(); @@ -47,6 +49,11 @@ class PredictorApi { } Predictor* fetch_predictor(std::string ep_name) { + std::map::iterator iter; + VLOG(2) << "going to print predictor names"; + for (iter = _endpoints.begin(); iter != _endpoints.end(); ++iter) { + VLOG(2) << "name: " << iter->first; + } std::map::iterator it = _endpoints.find(ep_name); if (it == _endpoints.end() || !it->second) { LOG(ERROR) << "Failed fetch predictor:" diff --git a/core/sdk-cpp/src/config_manager.cpp b/core/sdk-cpp/src/config_manager.cpp index aee1a2d4..8bc9f951 100644 --- a/core/sdk-cpp/src/config_manager.cpp +++ b/core/sdk-cpp/src/config_manager.cpp @@ -26,6 +26,13 @@ namespace sdk_cpp { using configure::SDKConf; +int EndpointConfigManager::create(const std::string& sdk_desc_str) { + if (load(sdk_desc_str) != 0) { + LOG(ERROR) << "Failed reload endpoint config"; + return -1; + } +} + int EndpointConfigManager::create(const char* path, const char* file) { _endpoint_config_path = path; _endpoint_config_file = file; @@ -38,6 +45,46 @@ int EndpointConfigManager::create(const char* path, const char* file) { return 0; } +int EndpointConfigManager::load(const std::string& sdk_desc_str) { + try { + SDKConf sdk_conf; + sdk_conf.ParseFromString(sdk_desc_str); + VariantInfo default_var; + if (init_one_variant(sdk_conf.default_variant_conf(), default_var) != 0) { + LOG(ERROR) << "Failed read default var conf"; + return -1; + } + + uint32_t ep_size = sdk_conf.predictors_size(); + for (uint32_t ei = 0; ei < ep_size; ++ei) { + EndpointInfo ep; + if (init_one_endpoint(sdk_conf.predictors(ei), ep, default_var) != 0) { + LOG(ERROR) << "Failed read endpoint info at: " << ei; + return -1; + } + + std::map::iterator it; + if (_ep_map.find(ep.endpoint_name) != _ep_map.end()) { + LOG(ERROR) << "Cannot insert duplicated endpoint" + << ", ep name: " << ep.endpoint_name; + } + + std::pair::iterator, bool> r = + _ep_map.insert(std::make_pair(ep.endpoint_name, ep)); + if (!r.second) { + LOG(ERROR) << "Failed insert endpoint, name" << ep.endpoint_name; + return -1; + } + } + } catch (std::exception& e) { + LOG(ERROR) << "Failed load configure" << e.what(); + return -1; + } + LOG(INFO) << "Success reload endpoint config file, id: " + << _current_endpointmap_id; + return 0; +} + int EndpointConfigManager::load() { try { SDKConf sdk_conf; diff --git a/core/sdk-cpp/src/predictor_sdk.cpp b/core/sdk-cpp/src/predictor_sdk.cpp index 8a8575a9..ae976446 100644 --- a/core/sdk-cpp/src/predictor_sdk.cpp +++ b/core/sdk-cpp/src/predictor_sdk.cpp @@ -30,6 +30,49 @@ int PredictorApi::register_all() { return 0; } +int PredictorApi::create(const std::string & api_desc_str) { + VLOG(2) << api_desc_str; + if (register_all() != 0) { + LOG(ERROR) << "Failed do register all!"; + return -1; + } + + if (_config_manager.create(api_desc_str) != 0) { + LOG(ERROR) << "Failed create config manager from desc string :" + << api_desc_str; + return -1; + } + + const std::map& map = _config_manager.config(); + std::map::const_iterator it; + for (it = map.begin(); it != map.end(); ++it) { + const EndpointInfo& ep_info = it->second; + Endpoint* ep = new (std::nothrow) Endpoint(); + if (ep->initialize(ep_info) != 0) { + LOG(ERROR) << "Failed intialize endpoint:" << ep_info.endpoint_name; + return -1; + } + + if (_endpoints.find(ep_info.endpoint_name) != _endpoints.end()) { + LOG(ERROR) << "Cannot insert duplicated endpoint:" + << ep_info.endpoint_name; + return -1; + } + + std::pair::iterator, bool> r = + _endpoints.insert(std::make_pair(ep_info.endpoint_name, ep)); + if (!r.second) { + LOG(ERROR) << "Failed insert endpoint:" << ep_info.endpoint_name; + return -1; + } + + LOG(INFO) << "Succ create endpoint instance with name: " + << ep_info.endpoint_name; + } + + return 0; +} + int PredictorApi::create(const char* path, const char* file) { if (register_all() != 0) { LOG(ERROR) << "Failed do register all!"; diff --git a/python/examples/imdb/benchmark.py b/python/examples/imdb/benchmark.py index 1931d759..71c77092 100644 --- a/python/examples/imdb/benchmark.py +++ b/python/examples/imdb/benchmark.py @@ -63,6 +63,6 @@ if __name__ == '__main__': thread_runner = MultiThreadRunner() result = thread_runner.run(predict, int(sys.argv[3]), resource) - print(result[-1]) + print("{}\t{}".format(sys.argv[3], sum(result[-1]) / len(result[-1]))) print("{}\t{}".format(sys.argv[3], sum(result[2]) / 1000.0 / 1000.0 / len(result[2]))) diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index 260f0869..88ee74a6 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -36,7 +36,7 @@ class SDKConfig(object): predictor_desc.service_name = \ "baidu.paddle_serving.predictor.general_model.GeneralModelService" predictor_desc.endpoint_router = "WeightedRandomRender" - predictor_desc.weighted_random_render_conf.variant_weight_list = "30" + predictor_desc.weighted_random_render_conf.variant_weight_list = "100" variant_desc = sdk.VariantConf() variant_desc.tag = "var1" @@ -105,12 +105,7 @@ class Client(object): predictor_sdk.set_server_endpoints(endpoints) sdk_desc = predictor_sdk.gen_desc() timestamp = time.asctime(time.localtime(time.time())) - predictor_path = "/tmp/" - predictor_file = "%s_predictor.conf" % timestamp - with open(predictor_path + predictor_file, "w") as fout: - fout.write(sdk_desc) - self.client_handle_.set_predictor_conf(predictor_path, predictor_file) - self.client_handle_.create_predictor() + self.client_handle_.create_predictor(sdk_desc) def get_feed_names(self): return self.feed_names_ -- GitLab