提交 8a82c479 编写于 作者: M MRXLT

timeline tool supports multi-process

上级 e6722e23
......@@ -45,12 +45,12 @@ class PredictorRes {
~PredictorRes() {}
public:
const std::vector<std::vector<int64_t>> & get_int64_by_name(
const std::string & name) {
const std::vector<std::vector<int64_t>>& get_int64_by_name(
const std::string& name) {
return _int64_map[name];
}
const std::vector<std::vector<float>> & get_float_by_name(
const std::string & name) {
const std::vector<std::vector<float>>& get_float_by_name(
const std::string& name) {
return _float_map[name];
}
......@@ -71,7 +71,7 @@ class PredictorClient {
void set_predictor_conf(const std::string& conf_path,
const std::string& conf_file);
int create_predictor_by_desc(const std::string & sdk_desc);
int create_predictor_by_desc(const std::string& sdk_desc);
int create_predictor();
int destroy_predictor();
......@@ -81,7 +81,8 @@ class PredictorClient {
const std::vector<std::vector<int64_t>>& int_feed,
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name,
PredictorRes & predict_res); // NOLINT
PredictorRes& predict_res, // NOLINT
const int& pid);
std::vector<std::vector<float>> predict(
const std::vector<std::vector<float>>& float_feed,
......
......@@ -137,7 +137,8 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name,
PredictorRes &predict_res) { // NOLINT
PredictorRes &predict_res,
const int &pid) { // NOLINT
predict_res._int64_map.clear();
predict_res._float_map.clear();
Timer timeline;
......@@ -241,6 +242,7 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
if (FLAGS_profile_client) {
std::ostringstream oss;
oss << "PROFILE\t"
<< "pid:" << pid << "\t"
<< "prepro_0:" << preprocess_start << " "
<< "prepro_1:" << preprocess_end << " "
<< "client_infer_0:" << client_infer_start << " "
......
......@@ -31,13 +31,15 @@ PYBIND11_MODULE(serving_client, m) {
py::class_<PredictorRes>(m, "PredictorRes", py::buffer_protocol())
.def(py::init())
.def("get_int64_by_name",
[](PredictorRes &self, std::string & name) {
[](PredictorRes &self, std::string &name) {
return self.get_int64_by_name(name);
}, py::return_value_policy::reference)
},
py::return_value_policy::reference)
.def("get_float_by_name",
[](PredictorRes &self, std::string & name) {
[](PredictorRes &self, std::string &name) {
return self.get_float_by_name(name);
}, py::return_value_policy::reference);
},
py::return_value_policy::reference);
py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
.def(py::init())
......@@ -56,26 +58,29 @@ PYBIND11_MODULE(serving_client, m) {
self.set_predictor_conf(conf_path, conf_file);
})
.def("create_predictor_by_desc",
[](PredictorClient &self, const std::string & sdk_desc) {
self.create_predictor_by_desc(sdk_desc); })
[](PredictorClient &self, const std::string &sdk_desc) {
self.create_predictor_by_desc(sdk_desc);
})
.def("create_predictor",
[](PredictorClient &self) { self.create_predictor(); })
.def("destroy_predictor",
[](PredictorClient &self) { self.destroy_predictor(); })
.def("predict",
[](PredictorClient &self,
const std::vector<std::vector<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name,
PredictorRes & predict_res) {
const std::vector<std::vector<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name,
PredictorRes &predict_res,
const int &pid) {
return self.predict(float_feed,
float_feed_name,
int_feed,
int_feed_name,
fetch_name,
predict_res);
predict_res,
pid);
})
.def("batch_predict",
[](PredictorClient &self,
......
......@@ -36,6 +36,7 @@ class BertService():
self.show_ids = show_ids
self.do_lower_case = do_lower_case
self.retry = retry
self.pid = os.getpid()
self.profile = True if ("FLAGS_profile_client" in os.environ and
os.environ["FLAGS_profile_client"]) else False
......@@ -78,7 +79,8 @@ class BertService():
}
prepro_end = time.time()
if self.profile:
print("PROFILE\tbert_pre_0:{} bert_pre_1:{}".format(
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
self.pid,
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map = self.client.predict(feed=feed, fetch=fetch)
......@@ -111,7 +113,8 @@ class BertService():
feed_batch.append(feed)
prepro_end = time.time()
if self.profile:
print("PROFILE\tbert_pre_0:{} bert_pre_1:{}".format(
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
self.pid,
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map_batch = self.client.batch_predict(
......
......@@ -5,8 +5,9 @@ import sys
profile_file = sys.argv[1]
def prase(line, counter):
event_list = line.split(" ")
def prase(pid_str, time_str, counter):
pid = pid_str.split(":")[1]
event_list = time_str.split(" ")
trace_list = []
for event in event_list:
name, ts = event.split(":")
......@@ -19,7 +20,7 @@ def prase(line, counter):
event_dict = {}
event_dict["name"] = name
event_dict["tid"] = 0
event_dict["pid"] = 0
event_dict["pid"] = pid
event_dict["ts"] = ts
event_dict["ph"] = ph
......@@ -36,7 +37,7 @@ if __name__ == "__main__":
for line in f.readlines():
line = line.strip().split("\t")
if line[0] == "PROFILE":
trace_list = prase(line[1], counter)
trace_list = prase(line[1], line[2], counter)
counter += 1
for trace in trace_list:
all_list.append(trace)
......
......@@ -78,6 +78,7 @@ class Client(object):
self.feed_types_ = {}
self.feed_names_to_idx_ = {}
self.rpath()
self.pid = os.getpid()
def rpath(self):
lib_path = os.path.dirname(paddle_serving_client.__file__)
......@@ -85,7 +86,6 @@ class Client(object):
lib_path = os.path.join(lib_path, 'lib')
os.popen('patchelf --set-rpath {} {}'.format(lib_path, client_path))
def load_client_config(self, path):
from .serving_client import PredictorClient
from .serving_client import PredictorRes
......@@ -128,9 +128,8 @@ class Client(object):
predictor_sdk.set_server_endpoints(endpoints)
sdk_desc = predictor_sdk.gen_desc()
print(sdk_desc)
self.client_handle_.create_predictor_by_desc(
sdk_desc.SerializeToString())
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
))
def get_feed_names(self):
return self.feed_names_
......@@ -144,6 +143,7 @@ class Client(object):
int_feed_names = []
float_feed_names = []
fetch_names = []
for key in feed:
if key not in self.feed_names_:
continue
......@@ -158,16 +158,18 @@ class Client(object):
if key in self.fetch_names_:
fetch_names.append(key)
ret = self.client_handle_.predict(
float_slot, float_feed_names, int_slot,
int_feed_names, fetch_names, self.result_handle_)
ret = self.client_handle_.predict(float_slot, float_feed_names,
int_slot, int_feed_names, fetch_names,
self.result_handle_, self.pid)
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = self.result_handle_.get_int64_by_name(name)[0]
result_map[name] = self.result_handle_.get_int64_by_name(name)[
0]
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = self.result_handle_.get_float_by_name(name)[0]
result_map[name] = self.result_handle_.get_float_by_name(name)[
0]
return result_map
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册