提交 1ad96012 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #177 from MRXLT/general-server-v1

bug fix && add debug for batch_predict
...@@ -214,7 +214,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict( ...@@ -214,7 +214,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
if (fetch_name.size() == 0) { if (fetch_name.size() == 0) {
return fetch_result_batch; return fetch_result_batch;
} }
fetch_result_batch.resize(batch_size); fetch_result_batch.resize(batch_size + 1);
int fetch_name_num = fetch_name.size(); int fetch_name_num = fetch_name.size();
for (int bi = 0; bi < batch_size; bi++) { for (int bi = 0; bi < batch_size; bi++) {
fetch_result_batch[bi].resize(fetch_name_num); fetch_result_batch[bi].resize(fetch_name_num);
...@@ -226,6 +226,9 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict( ...@@ -226,6 +226,9 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
VLOG(2) << "float feed name size: " << float_feed_name.size(); VLOG(2) << "float feed name size: " << float_feed_name.size();
VLOG(2) << "int feed name size: " << int_feed_name.size(); VLOG(2) << "int feed name size: " << int_feed_name.size();
Request req; Request req;
for (auto & name : fetch_name) {
req.add_fetch_var_names(name);
}
// //
for (int bi = 0; bi < batch_size; bi++) { for (int bi = 0; bi < batch_size; bi++) {
VLOG(2) << "prepare batch " << bi; VLOG(2) << "prepare batch " << bi;
...@@ -240,7 +243,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict( ...@@ -240,7 +243,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
for (auto &name : int_feed_name) { for (auto &name : int_feed_name) {
tensor_vec.push_back(inst->add_tensor_array()); tensor_vec.push_back(inst->add_tensor_array());
} }
VLOG(2) << "batch [" << bi << "] int_feed_name and float_feed_name" VLOG(2) << "batch [" << bi << "] int_feed_name and float_feed_name"
<< "prepared"; << "prepared";
int vec_idx = 0; int vec_idx = 0;
...@@ -305,6 +308,10 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict( ...@@ -305,6 +308,10 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
} }
} }
} }
//last index for infer time
fetch_result_batch[batch_size].resize(1);
fetch_result_batch[batch_size][0].resize(1);
fetch_result_batch[batch_size][0][0] = res.mean_infer_us();
} }
return fetch_result_batch; return fetch_result_batch;
......
...@@ -147,7 +147,7 @@ class Client(object): ...@@ -147,7 +147,7 @@ class Client(object):
return result_map return result_map
def batch_predict(self, feed_batch=[], fetch=[]): def batch_predict(self, feed_batch=[], fetch=[], debug=False):
int_slot_batch = [] int_slot_batch = []
float_slot_batch = [] float_slot_batch = []
int_feed_names = [] int_feed_names = []
...@@ -181,13 +181,18 @@ class Client(object): ...@@ -181,13 +181,18 @@ class Client(object):
fetch_names) fetch_names)
result_map_batch = [] result_map_batch = []
for result in result_batch: for result in result_batch[:-1]:
result_map = {} result_map = {}
for i, name in enumerate(fetch_names): for i, name in enumerate(fetch_names):
result_map[name] = result[i] result_map[name] = result[i]
result_map_batch.append(result_map) result_map_batch.append(result_map)
return result_map_batch infer_time = result_batch[-1][0][0]
if debug:
return result_map_batch, infer_time
else:
return result_map_batch
def release(self): def release(self):
self.client_handle_.destroy_predictor() self.client_handle_.destroy_predictor()
...@@ -17,7 +17,7 @@ from .proto import server_configure_pb2 as server_sdk ...@@ -17,7 +17,7 @@ from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format import google.protobuf.text_format
import tarfile import tarfile
import paddle_serving_server as paddle_serving_server import paddle_serving_server_gpu as paddle_serving_server
from version import serving_server_version from version import serving_server_version
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册