提交 93d25ba7 编写于 作者: G guru4elephant

fix predictor initialization bug

上级 76ae32df
...@@ -50,6 +50,8 @@ class PredictorClient { ...@@ -50,6 +50,8 @@ class PredictorClient {
void set_predictor_conf(const std::string& conf_path, void set_predictor_conf(const std::string& conf_path,
const std::string& conf_file); const std::string& conf_file);
int create_predictor_by_desc(const std::string & sdk_desc);
int create_predictor(); int create_predictor();
int destroy_predictor(); int destroy_predictor();
......
...@@ -88,7 +88,17 @@ int PredictorClient::destroy_predictor() { ...@@ -88,7 +88,17 @@ int PredictorClient::destroy_predictor() {
_api.destroy(); _api.destroy();
} }
int PredictorClient::create_predictor_by_desc(const std::string & sdk_desc) {
if (_api.create(sdk_desc) != 0) {
LOG(ERROR) << "Predictor Creation Failed";
return -1;
}
_api.thrd_initialize();
}
int PredictorClient::create_predictor() { int PredictorClient::create_predictor() {
VLOG(2) << "Predictor path: " << _predictor_path
<< " predictor file: " << _predictor_conf;
if (_api.create(_predictor_path.c_str(), _predictor_conf.c_str()) != 0) { if (_api.create(_predictor_path.c_str(), _predictor_conf.c_str()) != 0) {
LOG(ERROR) << "Predictor Creation Failed"; LOG(ERROR) << "Predictor Creation Failed";
return -1; return -1;
......
...@@ -41,6 +41,9 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -41,6 +41,9 @@ PYBIND11_MODULE(serving_client, m) {
const std::string &conf_file) { const std::string &conf_file) {
self.set_predictor_conf(conf_path, conf_file); self.set_predictor_conf(conf_path, conf_file);
}) })
.def("create_predictor_by_desc",
[](PredictorClient &self, const std::string & sdk_desc) {
self.create_predictor_by_desc(sdk_desc); })
.def("create_predictor", .def("create_predictor",
[](PredictorClient &self) { self.create_predictor(); }) [](PredictorClient &self) { self.create_predictor(); })
.def("destroy_predictor", .def("destroy_predictor",
......
...@@ -32,6 +32,10 @@ class EndpointConfigManager { ...@@ -32,6 +32,10 @@ class EndpointConfigManager {
EndpointConfigManager() EndpointConfigManager()
: _last_update_timestamp(0), _current_endpointmap_id(1) {} : _last_update_timestamp(0), _current_endpointmap_id(1) {}
int create(const std::string & sdk_desc_str);
int load(const std::string & sdk_desc_str);
int create(const char* path, const char* file); int create(const char* path, const char* file);
int load(); int load();
......
...@@ -31,6 +31,8 @@ class PredictorApi { ...@@ -31,6 +31,8 @@ class PredictorApi {
int register_all(); int register_all();
int create(const std::string & sdk_desc_str);
int create(const char* path, const char* file); int create(const char* path, const char* file);
int thrd_initialize(); int thrd_initialize();
...@@ -47,6 +49,11 @@ class PredictorApi { ...@@ -47,6 +49,11 @@ class PredictorApi {
} }
Predictor* fetch_predictor(std::string ep_name) { Predictor* fetch_predictor(std::string ep_name) {
std::map<std::string, Endpoint*>::iterator iter;
VLOG(2) << "going to print predictor names";
for (iter = _endpoints.begin(); iter != _endpoints.end(); ++iter) {
VLOG(2) << "name: " << iter->first;
}
std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name); std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name);
if (it == _endpoints.end() || !it->second) { if (it == _endpoints.end() || !it->second) {
LOG(ERROR) << "Failed fetch predictor:" LOG(ERROR) << "Failed fetch predictor:"
......
...@@ -26,6 +26,13 @@ namespace sdk_cpp { ...@@ -26,6 +26,13 @@ namespace sdk_cpp {
using configure::SDKConf; using configure::SDKConf;
int EndpointConfigManager::create(const std::string& sdk_desc_str) {
if (load(sdk_desc_str) != 0) {
LOG(ERROR) << "Failed reload endpoint config";
return -1;
}
}
int EndpointConfigManager::create(const char* path, const char* file) { int EndpointConfigManager::create(const char* path, const char* file) {
_endpoint_config_path = path; _endpoint_config_path = path;
_endpoint_config_file = file; _endpoint_config_file = file;
...@@ -38,6 +45,46 @@ int EndpointConfigManager::create(const char* path, const char* file) { ...@@ -38,6 +45,46 @@ int EndpointConfigManager::create(const char* path, const char* file) {
return 0; return 0;
} }
int EndpointConfigManager::load(const std::string& sdk_desc_str) {
try {
SDKConf sdk_conf;
sdk_conf.ParseFromString(sdk_desc_str);
VariantInfo default_var;
if (init_one_variant(sdk_conf.default_variant_conf(), default_var) != 0) {
LOG(ERROR) << "Failed read default var conf";
return -1;
}
uint32_t ep_size = sdk_conf.predictors_size();
for (uint32_t ei = 0; ei < ep_size; ++ei) {
EndpointInfo ep;
if (init_one_endpoint(sdk_conf.predictors(ei), ep, default_var) != 0) {
LOG(ERROR) << "Failed read endpoint info at: " << ei;
return -1;
}
std::map<std::string, EndpointInfo>::iterator it;
if (_ep_map.find(ep.endpoint_name) != _ep_map.end()) {
LOG(ERROR) << "Cannot insert duplicated endpoint"
<< ", ep name: " << ep.endpoint_name;
}
std::pair<std::map<std::string, EndpointInfo>::iterator, bool> r =
_ep_map.insert(std::make_pair(ep.endpoint_name, ep));
if (!r.second) {
LOG(ERROR) << "Failed insert endpoint, name" << ep.endpoint_name;
return -1;
}
}
} catch (std::exception& e) {
LOG(ERROR) << "Failed load configure" << e.what();
return -1;
}
LOG(INFO) << "Success reload endpoint config file, id: "
<< _current_endpointmap_id;
return 0;
}
int EndpointConfigManager::load() { int EndpointConfigManager::load() {
try { try {
SDKConf sdk_conf; SDKConf sdk_conf;
......
...@@ -30,6 +30,49 @@ int PredictorApi::register_all() { ...@@ -30,6 +30,49 @@ int PredictorApi::register_all() {
return 0; return 0;
} }
int PredictorApi::create(const std::string & api_desc_str) {
VLOG(2) << api_desc_str;
if (register_all() != 0) {
LOG(ERROR) << "Failed do register all!";
return -1;
}
if (_config_manager.create(api_desc_str) != 0) {
LOG(ERROR) << "Failed create config manager from desc string :"
<< api_desc_str;
return -1;
}
const std::map<std::string, EndpointInfo>& map = _config_manager.config();
std::map<std::string, EndpointInfo>::const_iterator it;
for (it = map.begin(); it != map.end(); ++it) {
const EndpointInfo& ep_info = it->second;
Endpoint* ep = new (std::nothrow) Endpoint();
if (ep->initialize(ep_info) != 0) {
LOG(ERROR) << "Failed intialize endpoint:" << ep_info.endpoint_name;
return -1;
}
if (_endpoints.find(ep_info.endpoint_name) != _endpoints.end()) {
LOG(ERROR) << "Cannot insert duplicated endpoint:"
<< ep_info.endpoint_name;
return -1;
}
std::pair<std::map<std::string, Endpoint*>::iterator, bool> r =
_endpoints.insert(std::make_pair(ep_info.endpoint_name, ep));
if (!r.second) {
LOG(ERROR) << "Failed insert endpoint:" << ep_info.endpoint_name;
return -1;
}
LOG(INFO) << "Succ create endpoint instance with name: "
<< ep_info.endpoint_name;
}
return 0;
}
int PredictorApi::create(const char* path, const char* file) { int PredictorApi::create(const char* path, const char* file) {
if (register_all() != 0) { if (register_all() != 0) {
LOG(ERROR) << "Failed do register all!"; LOG(ERROR) << "Failed do register all!";
......
...@@ -63,6 +63,6 @@ if __name__ == '__main__': ...@@ -63,6 +63,6 @@ if __name__ == '__main__':
thread_runner = MultiThreadRunner() thread_runner = MultiThreadRunner()
result = thread_runner.run(predict, int(sys.argv[3]), resource) result = thread_runner.run(predict, int(sys.argv[3]), resource)
print(result[-1])
print("{}\t{}".format(sys.argv[3], sum(result[-1]) / len(result[-1]))) print("{}\t{}".format(sys.argv[3], sum(result[-1]) / len(result[-1])))
print("{}\t{}".format(sys.argv[3], sum(result[2]) / 1000.0 / 1000.0 / len(result[2]))) print("{}\t{}".format(sys.argv[3], sum(result[2]) / 1000.0 / 1000.0 / len(result[2])))
...@@ -36,7 +36,7 @@ class SDKConfig(object): ...@@ -36,7 +36,7 @@ class SDKConfig(object):
predictor_desc.service_name = \ predictor_desc.service_name = \
"baidu.paddle_serving.predictor.general_model.GeneralModelService" "baidu.paddle_serving.predictor.general_model.GeneralModelService"
predictor_desc.endpoint_router = "WeightedRandomRender" predictor_desc.endpoint_router = "WeightedRandomRender"
predictor_desc.weighted_random_render_conf.variant_weight_list = "30" predictor_desc.weighted_random_render_conf.variant_weight_list = "100"
variant_desc = sdk.VariantConf() variant_desc = sdk.VariantConf()
variant_desc.tag = "var1" variant_desc.tag = "var1"
...@@ -105,12 +105,7 @@ class Client(object): ...@@ -105,12 +105,7 @@ class Client(object):
predictor_sdk.set_server_endpoints(endpoints) predictor_sdk.set_server_endpoints(endpoints)
sdk_desc = predictor_sdk.gen_desc() sdk_desc = predictor_sdk.gen_desc()
timestamp = time.asctime(time.localtime(time.time())) timestamp = time.asctime(time.localtime(time.time()))
predictor_path = "/tmp/" self.client_handle_.create_predictor(sdk_desc)
predictor_file = "%s_predictor.conf" % timestamp
with open(predictor_path + predictor_file, "w") as fout:
fout.write(sdk_desc)
self.client_handle_.set_predictor_conf(predictor_path, predictor_file)
self.client_handle_.create_predictor()
def get_feed_names(self): def get_feed_names(self):
return self.feed_names_ return self.feed_names_
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册