Commit d44faecb authored by M MRXLT

fix batch size for lod_tensor && add debug for batch_predict

Parent ce071fdf
...@@ -214,7 +214,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
   if (fetch_name.size() == 0) {
     return fetch_result_batch;
   }
-  fetch_result_batch.resize(batch_size);
+  fetch_result_batch.resize(batch_size + 1);
   int fetch_name_num = fetch_name.size();
   for (int bi = 0; bi < batch_size; bi++) {
     fetch_result_batch[bi].resize(fetch_name_num);
...@@ -226,6 +226,9 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
   VLOG(2) << "float feed name size: " << float_feed_name.size();
   VLOG(2) << "int feed name size: " << int_feed_name.size();
   Request req;
+  for (auto & name : fetch_name) {
+    req.add_fetch_var_names(name);
+  }
   //
   for (int bi = 0; bi < batch_size; bi++) {
     VLOG(2) << "prepare batch " << bi;
...@@ -240,7 +243,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
     for (auto &name : int_feed_name) {
       tensor_vec.push_back(inst->add_tensor_array());
     }
     VLOG(2) << "batch [" << bi << "] int_feed_name and float_feed_name"
             << "prepared";
     int vec_idx = 0;
...@@ -305,6 +308,10 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
         }
       }
     }
+    //last index for infer time
+    fetch_result_batch[batch_size].resize(1);
+    fetch_result_batch[batch_size][0].resize(1);
+    fetch_result_batch[batch_size][0][0] = res.mean_infer_us();
   }
   return fetch_result_batch;
......
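Note on the client-side change: batch_predict now reserves one extra slot at the end of fetch_result_batch and stores res.mean_infer_us() there, so the returned structure carries the mean inference time alongside the per-sample fetch results. A minimal sketch of the resulting layout, with made-up values and a hypothetical fetch name "prob":

    # Illustrative layout only; the values and the fetch name "prob" are invented.
    result_batch = [
        [[0.1, 0.9]],   # sample 0: one entry per fetch variable
        [[0.7, 0.3]],   # sample 1
        [[523.0]],      # extra slot: mean inference time in microseconds
    ]
    predictions = result_batch[:-1]         # per-sample fetch results
    mean_infer_us = result_batch[-1][0][0]  # timing value added by this commit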
...@@ -53,11 +53,18 @@ int GeneralInferOp::inference() {
   const TensorVector *in = &reader_out->tensor_vector;
   TensorVector *out = butil::get_object<TensorVector>();
-  int batch_size = (*in)[0].shape[0];
+  int batch_size = 0;
+  if ((*in)[0].lod.size() == 1) {
+    batch_size = (*in)[0].lod[0].size() - 1;
+  }
+  else {
+    batch_size = (*in)[0].shape[0];
+  }
   // infer
   Timer timeline;
   double infer_time = 0.0;
   timeline.Start();
+  VLOG(2) << "batch size : " << batch_size;
   if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) {
     LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
     return -1;
...@@ -70,7 +77,7 @@ int GeneralInferOp::inference() {
   VLOG(2) << "start to call load general model_conf op";
   baidu::paddle_serving::predictor::Resource &resource =
       baidu::paddle_serving::predictor::Resource::instance();
   VLOG(2) << "get resource pointer done.";
   std::shared_ptr<PaddleGeneralModelConfig> model_config =
       resource.get_general_model_config();
...@@ -81,7 +88,7 @@ int GeneralInferOp::inference() {
     fetch_index[i] =
         model_config->_fetch_alias_name_to_index[req->fetch_var_names(i)];
   }
   // response inst with only fetch_var_names
   Response *res = mutable_data<Response>();
...@@ -94,7 +101,7 @@ int GeneralInferOp::inference() {
     // currently only response float tensor or lod_tensor
     tensor->set_elem_type(1);
     if (model_config->_is_lod_fetch[idx]) {
-      VLOG(2) << "out[" << idx << " is lod_tensor";
+      VLOG(2) << "out[" << idx << "] is lod_tensor";
       tensor->add_shape(-1);
     } else {
       VLOG(2) << "out[" << idx << "] is tensor";
......
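The batch-size fix follows the LoD convention: for a level-0 LoD tensor, lod[0] stores batch_size + 1 cumulative sequence offsets, so the number of sequences is lod[0].size() - 1; for a dense tensor the batch size remains shape[0]. A small sketch of the rule, with assumed example values:

    # Sketch of the rule applied above; the offsets and shapes are example values.
    def batch_size_of(lod, shape):
        if len(lod) == 1:      # lod_tensor: lod[0] holds cumulative offsets
            return len(lod[0]) - 1
        return shape[0]        # dense tensor: first dimension is the batch size

    print(batch_size_of(lod=[[0, 3, 7, 9]], shape=[9, 128]))  # 3 sequences
    print(batch_size_of(lod=[], shape=[4, 128]))              # 4 samples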
...@@ -147,7 +147,7 @@ class Client(object):
         return result_map
-    def batch_predict(self, feed_batch=[], fetch=[]):
+    def batch_predict(self, feed_batch=[], fetch=[], debug=False):
         int_slot_batch = []
         float_slot_batch = []
         int_feed_names = []
...@@ -181,13 +181,18 @@ class Client(object):
             fetch_names)
         result_map_batch = []
-        for result in result_batch:
+        for result in result_batch[:-1]:
             result_map = {}
             for i, name in enumerate(fetch_names):
                 result_map[name] = result[i]
             result_map_batch.append(result_map)
-        return result_map_batch
+        infer_time = result_batch[-1][0][0]
+        if debug:
+            return result_map_batch, infer_time
+        else:
+            return result_map_batch
     def release(self):
         self.client_handle_.destroy_predictor()
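With the new debug flag, a caller can retrieve the server-side mean inference time along with the predictions. A hedged usage sketch; the feed name "words" and fetch name "prob" are placeholders, not names from this commit:

    # Assumes an already-connected Client instance; feed/fetch names are hypothetical.
    feed_batch = [{"words": [1, 2, 3]}, {"words": [4, 5, 6]}]

    result_map_batch = client.batch_predict(feed_batch=feed_batch, fetch=["prob"])
    result_map_batch, infer_time = client.batch_predict(
        feed_batch=feed_batch, fetch=["prob"], debug=True)
    print("mean infer time (us):", infer_time)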
...@@ -17,7 +17,7 @@ from .proto import server_configure_pb2 as server_sdk
 from .proto import general_model_config_pb2 as m_config
 import google.protobuf.text_format
 import tarfile
-import paddle_serving_server as paddle_serving_server
+import paddle_serving_server_gpu as paddle_serving_server
 from version import serving_server_version
......