Commit 6324cfe9 authored by MRXLT

add timeline for batch_predict

Parent 77dd8f88
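The change brackets each phase of batch_predict (assembling the feed request, the RPC to the serving instance, and copying the fetch results) with microsecond timestamps, then prints one PROFILE line to stderr when client profiling is enabled. Below is a minimal sketch of that pattern; the Timer is a std::chrono stand-in and profile_client replaces the FLAGS_profile_client gflag, both assumptions for illustration rather than the serving client's real types.

```cpp
// Sketch of the timing pattern this commit adds to batch_predict.
// Timer and profile_client are illustrative stand-ins, not the real
// framework Timer or the FLAGS_profile_client gflag.
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <sstream>

struct Timer {
  int64_t TimeStampUS() const {
    return std::chrono::duration_cast<std::chrono::microseconds>(
               std::chrono::system_clock::now().time_since_epoch())
        .count();
  }
};

static bool profile_client = true;

void instrumented_batch_call() {
  Timer timeline;

  int64_t preprocess_start = timeline.TimeStampUS();
  // ... build the request: copy feed tensors for every sample in the batch ...
  int64_t preprocess_end = timeline.TimeStampUS();

  int64_t client_infer_start = timeline.TimeStampUS();
  // ... RPC to the serving instance ...
  int64_t client_infer_end = timeline.TimeStampUS();

  // Post-processing starts as soon as the RPC returns.
  int64_t postprocess_start = client_infer_end;
  // ... copy fetch tensors into the per-sample result vectors ...
  int64_t postprocess_end = timeline.TimeStampUS();

  if (profile_client) {
    // One PROFILE line per call: <phase>_0 is the start stamp, <phase>_1 the end.
    std::ostringstream oss;
    oss << "PROFILE\t"
        << "prepro_0:" << preprocess_start << " "
        << "prepro_1:" << preprocess_end << " "
        << "client_infer_0:" << client_infer_start << " "
        << "client_infer_1:" << client_infer_end << " "
        << "postpro_0:" << postprocess_start << " "
        << "postpro_1:" << postprocess_end;
    fprintf(stderr, "%s\n", oss.str().c_str());
  }
}
```

When FLAGS_profile_server is also set, the diff below additionally forwards the flag in the request (req.set_profile_server(true)) and splices the server-side op timestamps returned in the response into the same PROFILE line.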
@@ -145,7 +145,7 @@ std::vector<std::vector<float>> PredictorClient::predict(
int64_t preprocess_start = timeline.TimeStampUS();
// we save infer_us at fetch_result[fetch_name.size()]
fetch_result.resize(fetch_name.size() + 1);
fetch_result.resize(fetch_name.size());
_api.thrd_clear();
_predictor = _api.fetch_predictor("general_model");
@@ -276,7 +276,11 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
if (fetch_name.size() == 0) {
return fetch_result_batch;
}
fetch_result_batch.resize(batch_size + 1);
Timer timeline;
int64_t preprocess_start = timeline.TimeStampUS();
fetch_result_batch.resize(batch_size);
int fetch_name_num = fetch_name.size();
for (int bi = 0; bi < batch_size; bi++) {
fetch_result_batch[bi].resize(fetch_name_num);
@@ -349,13 +353,30 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
<< "itn feed value prepared";
}
int64_t preprocess_end = timeline.TimeStampUS();
int64_t client_infer_start = timeline.TimeStampUS();
Response res;
int64_t client_infer_end = 0;
int64_t postprocess_start = 0;
int64_t postprocess_end = 0;
if (FLAGS_profile_client) {
if (FLAGS_profile_server) {
req.set_profile_server(true);
}
}
res.Clear();
if (_predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString();
exit(-1);
} else {
client_infer_end = timeline.TimeStampUS();
postprocess_start = client_infer_end;
for (int bi = 0; bi < batch_size; bi++) {
for (auto &name : fetch_name) {
int idx = _fetch_name_to_idx[name];
@@ -372,8 +393,30 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
}
}
}
postprocess_end = timeline.TimeStampUS();
}
if (FLAGS_profile_client) {
std::ostringstream oss;
oss << "PROFILE\t"
<< "prepro_0:" << preprocess_start << " "
<< "prepro_1:" << preprocess_end << " "
<< "client_infer_0:" << client_infer_start << " "
<< "client_infer_1:" << client_infer_end << " ";
if (FLAGS_profile_server) {
int op_num = res.profile_time_size() / 2;
for (int i = 0; i < op_num; ++i) {
oss << "op" << i << "_0:" << res.profile_time(i * 2) << " ";
oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
}
}
oss << "postpro_0:" << postprocess_start << " ";
oss << "postpro_1:" << postprocess_end;
fprintf(stderr, "%s\n", oss.str().c_str());
}
return fetch_result_batch;
}
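Each key:value pair in the emitted PROFILE line is a raw microsecond timestamp, so the cost of a phase is the difference between its _1 and _0 entries. The following stand-alone parser is hypothetical (not part of this commit) and only illustrates how such a line could be turned into per-phase durations; the sample values are made up.

```cpp
// Hypothetical helper: parse one PROFILE line of the form emitted above
// and print the duration of each client-side phase in microseconds.
#include <cstdint>
#include <iostream>
#include <map>
#include <sstream>
#include <string>

int main() {
  // Example line with made-up timestamps (server op entries omitted).
  std::string line =
      "PROFILE\tprepro_0:100 prepro_1:250 client_infer_0:250 "
      "client_infer_1:900 postpro_0:900 postpro_1:950";

  // Drop the leading "PROFILE\t" tag, then read space-separated key:value pairs.
  std::map<std::string, int64_t> stamps;
  std::istringstream iss(line.substr(line.find('\t') + 1));
  std::string token;
  while (iss >> token) {
    size_t colon = token.find(':');
    stamps[token.substr(0, colon)] = std::stoll(token.substr(colon + 1));
  }

  // Each phase is bracketed by <name>_0 (start) and <name>_1 (end).
  for (const std::string &phase : {"prepro", "client_infer", "postpro"}) {
    std::cout << phase << ": "
              << stamps[phase + "_1"] - stamps[phase + "_0"] << " us\n";
  }
  return 0;
}
```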
@@ -89,8 +89,8 @@ class Client(object):
self.client_handle_ = PredictorClient()
self.client_handle_.init(path)
read_env_flags = ["profile_client", "profile_server"]
self.client_handle_.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
self.client_handle_.init_gflags([sys.argv[
0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.feed_shapes_ = [var.shape for var in model_conf.feed_var]
@@ -183,18 +183,13 @@ class Client(object):
fetch_names)
result_map_batch = []
for result in result_batch[:-1]:
for result in result_batch:
result_map = {}
for i, name in enumerate(fetch_names):
result_map[name] = result[i]
result_map_batch.append(result_map)
infer_time = result_batch[-1][0][0]
if profile:
return result_map_batch, infer_time
else:
return result_map_batch
return result_map_batch
def release(self):
self.client_handle_.destroy_predictor()