Commit 1ad96012 authored by Dong Daxiang, committed by GitHub

Merge pull request #177 from MRXLT/general-server-v1

bug fix && add debug for batch_predict
@@ -214,7 +214,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
if (fetch_name.size() == 0) {
return fetch_result_batch;
}
-  fetch_result_batch.resize(batch_size);
+  fetch_result_batch.resize(batch_size + 1);
int fetch_name_num = fetch_name.size();
for (int bi = 0; bi < batch_size; bi++) {
fetch_result_batch[bi].resize(fetch_name_num);
@@ -226,6 +226,9 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
VLOG(2) << "float feed name size: " << float_feed_name.size();
VLOG(2) << "int feed name size: " << int_feed_name.size();
Request req;
+  for (auto &name : fetch_name) {
+    req.add_fetch_var_names(name);
+  }
//
for (int bi = 0; bi < batch_size; bi++) {
VLOG(2) << "prepare batch " << bi;
@@ -240,7 +243,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
for (auto &name : int_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
VLOG(2) << "batch [" << bi << "] int_feed_name and float_feed_name"
<< "prepared";
int vec_idx = 0;
@@ -305,6 +308,10 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
}
}
}
+  // last index holds the mean inference time in microseconds
+  fetch_result_batch[batch_size].resize(1);
+  fetch_result_batch[batch_size][0].resize(1);
+  fetch_result_batch[batch_size][0][0] = res.mean_infer_us();
}
return fetch_result_batch;
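Taken together, the C++ changes above reserve one extra slot at index `batch_size`: real samples occupy indices `0..batch_size-1`, and the trailing slot carries only `res.mean_infer_us()`. A minimal sketch of the resulting layout (illustrative values, not data from this PR):

```python
# Sketch of the 3-level structure batch_predict now returns.
# Values are made up; the real numbers come from the serving response.
batch_size = 2

fetch_result_batch = [
    [[0.9, 0.1]],   # sample 0: one inner vector per fetch variable
    [[0.2, 0.8]],   # sample 1
    [[350.0]],      # extra slot at index batch_size: mean infer time (us)
]

assert len(fetch_result_batch) == batch_size + 1
infer_time = fetch_result_batch[-1][0][0]   # what the Python client reads
```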
@@ -147,7 +147,7 @@ class Client(object):
return result_map
-    def batch_predict(self, feed_batch=[], fetch=[]):
+    def batch_predict(self, feed_batch=[], fetch=[], debug=False):
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
@@ -181,13 +181,18 @@ class Client(object):
fetch_names)
result_map_batch = []
-        for result in result_batch:
+        for result in result_batch[:-1]:
result_map = {}
for i, name in enumerate(fetch_names):
result_map[name] = result[i]
result_map_batch.append(result_map)
-        return result_map_batch
+        infer_time = result_batch[-1][0][0]
+        if debug:
+            return result_map_batch, infer_time
+        else:
+            return result_map_batch
def release(self):
self.client_handle_.destroy_predictor()
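On the Python side the extra slot is popped off before results are returned, so existing callers are unaffected unless they opt in with `debug=True`. A hedged usage sketch (endpoint, config path, and feed/fetch names are placeholders, not taken from this PR):

```python
from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")  # placeholder path
client.connect(["127.0.0.1:9292"])                         # placeholder endpoint

feed_batch = [{"words": [1, 2, 3]}, {"words": [4, 5, 6]}]  # placeholder feed

# New behavior: debug=True additionally returns the server-side
# mean inference time in microseconds.
result_map_batch, infer_time = client.batch_predict(
    feed_batch=feed_batch, fetch=["prob"], debug=True)

# Old call sites keep working: the default debug=False returns
# only the per-sample result maps.
result_map_batch = client.batch_predict(feed_batch=feed_batch, fetch=["prob"])

client.release()
```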
@@ -17,7 +17,7 @@ from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
import tarfile
-import paddle_serving_server as paddle_serving_server
+import paddle_serving_server_gpu as paddle_serving_server
from version import serving_server_version
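Why the alias matters: this file lives in the GPU package, but the rest of the module refers to the package by the generic name `paddle_serving_server`, so pointing the alias at the GPU implementation fixes the GPU build without touching any other line. A minimal sketch of the pattern (the `Server` attribute below is illustrative, not quoted from this PR):

```python
# GPU build: bind the generic name to the GPU implementation so all
# later references resolve to the GPU package.
import paddle_serving_server_gpu as paddle_serving_server

# Illustrative downstream use; actual attributes depend on the package.
server = paddle_serving_server.Server()
```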