Commit 93d25ba7 authored by: G guru4elephant

fix predictor initialization bug

Parent 76ae32df
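
For context, this change replaces the temp-file initialization path (write sdk_desc to /tmp, then set_predictor_conf + create_predictor) with one that accepts the serialized SDKConf string directly. A minimal sketch of the new client-side flow, using only the methods added below; the include path and surrounding namespace are assumptions, not part of the diff:

    #include <string>
    // #include "general_model.h"   // header declaring PredictorClient (path assumed)

    int init_predictor(const std::string& sdk_desc) {
      PredictorClient client;
      // New entry point: build the SDK from the serialized SDKConf string,
      // instead of set_predictor_conf() + create_predictor() on a conf file.
      if (client.create_predictor_by_desc(sdk_desc) != 0) {
        return -1;
      }
      // ... issue prediction requests ...
      client.destroy_predictor();
      return 0;
    }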
......@@ -50,6 +50,8 @@ class PredictorClient {
void set_predictor_conf(const std::string& conf_path,
const std::string& conf_file);
int create_predictor_by_desc(const std::string & sdk_desc);
int create_predictor();
int destroy_predictor();
......
......@@ -88,7 +88,17 @@ int PredictorClient::destroy_predictor() {
_api.destroy();
return 0;
}
int PredictorClient::create_predictor_by_desc(const std::string& sdk_desc) {
// Initialize the SDK directly from a serialized SDKConf string.
if (_api.create(sdk_desc) != 0) {
LOG(ERROR) << "Predictor Creation Failed";
return -1;
}
_api.thrd_initialize();
return 0;
}
int PredictorClient::create_predictor() {
VLOG(2) << "Predictor path: " << _predictor_path
<< " predictor file: " << _predictor_conf;
if (_api.create(_predictor_path.c_str(), _predictor_conf.c_str()) != 0) {
LOG(ERROR) << "Predictor Creation Failed";
return -1;
......
......@@ -41,6 +41,9 @@ PYBIND11_MODULE(serving_client, m) {
const std::string &conf_file) {
self.set_predictor_conf(conf_path, conf_file);
})
.def("create_predictor_by_desc",
[](PredictorClient &self, const std::string & sdk_desc) {
self.create_predictor_by_desc(sdk_desc); })
.def("create_predictor",
[](PredictorClient &self) { self.create_predictor(); })
.def("destroy_predictor",
......
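One thing the bindings above silently do is drop the int status codes returned by the C++ methods, so a failed create is invisible to the Python caller. A hedged sketch of the same bindings with the codes propagated (module name taken from the hunk; the class registration details and header include are assumptions, not part of the commit):

    #include <pybind11/pybind11.h>
    #include <string>
    // #include "general_model.h"   // header declaring PredictorClient (path assumed)

    namespace py = pybind11;

    PYBIND11_MODULE(serving_client, m) {
      py::class_<PredictorClient>(m, "PredictorClient")   // registration details assumed
          .def(py::init<>())
          .def("create_predictor_by_desc",
               [](PredictorClient &self, const std::string &sdk_desc) {
                 // Return the status code so Python can detect a failed initialization.
                 return self.create_predictor_by_desc(sdk_desc);
               })
          .def("create_predictor",
               [](PredictorClient &self) { return self.create_predictor(); })
          .def("destroy_predictor",
               [](PredictorClient &self) { return self.destroy_predictor(); });
    }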
......@@ -32,6 +32,10 @@ class EndpointConfigManager {
EndpointConfigManager()
: _last_update_timestamp(0), _current_endpointmap_id(1) {}
int create(const std::string & sdk_desc_str);
int load(const std::string & sdk_desc_str);
int create(const char* path, const char* file);
int load();
......
......@@ -31,6 +31,8 @@ class PredictorApi {
int register_all();
int create(const std::string & sdk_desc_str);
int create(const char* path, const char* file);
int thrd_initialize();
......@@ -47,6 +49,11 @@ class PredictorApi {
}
Predictor* fetch_predictor(std::string ep_name) {
std::map<std::string, Endpoint*>::iterator iter;
VLOG(2) << "going to print predictor names";
for (iter = _endpoints.begin(); iter != _endpoints.end(); ++iter) {
VLOG(2) << "name: " << iter->first;
}
std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name);
if (it == _endpoints.end() || !it->second) {
LOG(ERROR) << "Failed fetch predictor:"
......
......@@ -26,6 +26,13 @@ namespace sdk_cpp {
using configure::SDKConf;
int EndpointConfigManager::create(const std::string& sdk_desc_str) {
if (load(sdk_desc_str) != 0) {
LOG(ERROR) << "Failed to reload endpoint config";
return -1;
}
return 0;
}
int EndpointConfigManager::create(const char* path, const char* file) {
_endpoint_config_path = path;
_endpoint_config_file = file;
......@@ -38,6 +45,46 @@ int EndpointConfigManager::create(const char* path, const char* file) {
return 0;
}
int EndpointConfigManager::load(const std::string& sdk_desc_str) {
try {
SDKConf sdk_conf;
sdk_conf.ParseFromString(sdk_desc_str);
VariantInfo default_var;
if (init_one_variant(sdk_conf.default_variant_conf(), default_var) != 0) {
LOG(ERROR) << "Failed read default var conf";
return -1;
}
uint32_t ep_size = sdk_conf.predictors_size();
for (uint32_t ei = 0; ei < ep_size; ++ei) {
EndpointInfo ep;
if (init_one_endpoint(sdk_conf.predictors(ei), ep, default_var) != 0) {
LOG(ERROR) << "Failed read endpoint info at: " << ei;
return -1;
}
if (_ep_map.find(ep.endpoint_name) != _ep_map.end()) {
LOG(ERROR) << "Cannot insert duplicated endpoint"
<< ", ep name: " << ep.endpoint_name;
return -1;
}
std::pair<std::map<std::string, EndpointInfo>::iterator, bool> r =
_ep_map.insert(std::make_pair(ep.endpoint_name, ep));
if (!r.second) {
LOG(ERROR) << "Failed insert endpoint, name" << ep.endpoint_name;
return -1;
}
}
} catch (std::exception& e) {
LOG(ERROR) << "Failed load configure" << e.what();
return -1;
}
LOG(INFO) << "Success reload endpoint config file, id: "
<< _current_endpointmap_id;
return 0;
}
int EndpointConfigManager::load() {
try {
SDKConf sdk_conf;
......
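For reference, a minimal sketch of producing the string that the new load(const std::string&) overload consumes; it relies only on configure::SDKConf (already parsed above with ParseFromString) and standard protobuf serialization, with the field population elided:

    // Sketch: serialize an SDKConf and hand it to the new string-based entry point.
    int build_and_load(EndpointConfigManager& conf_manager) {
      configure::SDKConf sdk_conf;
      // ... populate default_variant_conf and predictors, as the Python SDKConfig helper does ...
      std::string sdk_desc;
      sdk_conf.SerializeToString(&sdk_desc);
      return conf_manager.create(sdk_desc);   // delegates to load(sdk_desc)
    }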
......@@ -30,6 +30,49 @@ int PredictorApi::register_all() {
return 0;
}
int PredictorApi::create(const std::string & api_desc_str) {
VLOG(2) << api_desc_str;
if (register_all() != 0) {
LOG(ERROR) << "Failed do register all!";
return -1;
}
if (_config_manager.create(api_desc_str) != 0) {
LOG(ERROR) << "Failed create config manager from desc string :"
<< api_desc_str;
return -1;
}
const std::map<std::string, EndpointInfo>& map = _config_manager.config();
std::map<std::string, EndpointInfo>::const_iterator it;
for (it = map.begin(); it != map.end(); ++it) {
const EndpointInfo& ep_info = it->second;
Endpoint* ep = new (std::nothrow) Endpoint();
if (ep->initialize(ep_info) != 0) {
LOG(ERROR) << "Failed intialize endpoint:" << ep_info.endpoint_name;
return -1;
}
if (_endpoints.find(ep_info.endpoint_name) != _endpoints.end()) {
LOG(ERROR) << "Cannot insert duplicated endpoint:"
<< ep_info.endpoint_name;
return -1;
}
std::pair<std::map<std::string, Endpoint*>::iterator, bool> r =
_endpoints.insert(std::make_pair(ep_info.endpoint_name, ep));
if (!r.second) {
LOG(ERROR) << "Failed insert endpoint:" << ep_info.endpoint_name;
return -1;
}
LOG(INFO) << "Succ create endpoint instance with name: "
<< ep_info.endpoint_name;
}
return 0;
}
int PredictorApi::create(const char* path, const char* file) {
if (register_all() != 0) {
LOG(ERROR) << "Failed do register all!";
......
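Putting the pieces together, a sketch of the SDK-level flow that create_predictor_by_desc drives; every call appears in this diff, while the endpoint name "general_model" is only an assumption for illustration:

    int run_with_desc(const std::string& sdk_desc) {
      PredictorApi api;
      if (api.create(sdk_desc) != 0) {   // register_all() + endpoint setup from the SDKConf string
        return -1;
      }
      api.thrd_initialize();             // per-thread init, as done in create_predictor_by_desc()
      Predictor* predictor = api.fetch_predictor("general_model");  // endpoint name assumed
      if (predictor == NULL) {
        return -1;
      }
      // ... issue inference requests through the fetched predictor ...
      api.destroy();
      return 0;
    }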
......@@ -63,6 +63,6 @@ if __name__ == '__main__':
thread_runner = MultiThreadRunner()
result = thread_runner.run(predict, int(sys.argv[3]), resource)
print(result[-1])
print("{}\t{}".format(sys.argv[3], sum(result[-1]) / len(result[-1])))
print("{}\t{}".format(sys.argv[3], sum(result[2]) / 1000.0 / 1000.0 / len(result[2])))
......@@ -36,7 +36,7 @@ class SDKConfig(object):
predictor_desc.service_name = \
"baidu.paddle_serving.predictor.general_model.GeneralModelService"
predictor_desc.endpoint_router = "WeightedRandomRender"
predictor_desc.weighted_random_render_conf.variant_weight_list = "30"
predictor_desc.weighted_random_render_conf.variant_weight_list = "100"
variant_desc = sdk.VariantConf()
variant_desc.tag = "var1"
......@@ -105,12 +105,7 @@ class Client(object):
predictor_sdk.set_server_endpoints(endpoints)
sdk_desc = predictor_sdk.gen_desc()
timestamp = time.asctime(time.localtime(time.time()))
predictor_path = "/tmp/"
predictor_file = "%s_predictor.conf" % timestamp
with open(predictor_path + predictor_file, "w") as fout:
fout.write(sdk_desc)
self.client_handle_.set_predictor_conf(predictor_path, predictor_file)
self.client_handle_.create_predictor()
self.client_handle_.create_predictor_by_desc(sdk_desc)
def get_feed_names(self):
return self.feed_names_
......