Commit d44faecb authored by MRXLT

fix batch size for lod_tensor && add debug for batch_predict

Parent ce071fdf
......@@ -214,7 +214,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
  if (fetch_name.size() == 0) {
    return fetch_result_batch;
  }
-  fetch_result_batch.resize(batch_size);
+  fetch_result_batch.resize(batch_size + 1);
  int fetch_name_num = fetch_name.size();
  for (int bi = 0; bi < batch_size; bi++) {
    fetch_result_batch[bi].resize(fetch_name_num);
......@@ -226,6 +226,9 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
  VLOG(2) << "float feed name size: " << float_feed_name.size();
  VLOG(2) << "int feed name size: " << int_feed_name.size();
  Request req;
+  for (auto &name : fetch_name) {
+    req.add_fetch_var_names(name);
+  }
  //
  for (int bi = 0; bi < batch_size; bi++) {
    VLOG(2) << "prepare batch " << bi;
......@@ -240,7 +243,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
    for (auto &name : int_feed_name) {
      tensor_vec.push_back(inst->add_tensor_array());
    }
    VLOG(2) << "batch [" << bi << "] int_feed_name and float_feed_name"
            << "prepared";
    int vec_idx = 0;
......@@ -305,6 +308,10 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
        }
      }
    }
+    // last index for infer time
+    fetch_result_batch[batch_size].resize(1);
+    fetch_result_batch[batch_size][0].resize(1);
+    fetch_result_batch[batch_size][0][0] = res.mean_infer_us();
  }
  return fetch_result_batch;
......
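Note: the hunks above change the result layout of batch_predict. fetch_result_batch now has batch_size + 1 rows, and the extra trailing row carries a single value, res.mean_infer_us(). A minimal Python sketch of this packing convention, for illustration only (not the Serving client itself; the numbers are made up):

def unpack_batch_result(fetch_result_batch, fetch_names):
    """Split the per-sample results from the trailing timing slot."""
    result_map_batch = []
    for result in fetch_result_batch[:-1]:        # last row is not a sample
        result_map_batch.append(
            {name: result[i] for i, name in enumerate(fetch_names)})
    mean_infer_us = fetch_result_batch[-1][0][0]  # appended timing slot
    return result_map_batch, mean_infer_us

# Example with batch_size = 2 and one fetch var "prob":
raw = [[[0.1, 0.9]], [[0.8, 0.2]], [[532.0]]]     # last row: mean infer time (us)
maps, t = unpack_batch_result(raw, ["prob"])
print(maps)  # [{'prob': [0.1, 0.9]}, {'prob': [0.8, 0.2]}]
print(t)     # 532.0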
......@@ -53,11 +53,18 @@ int GeneralInferOp::inference() {
  const TensorVector *in = &reader_out->tensor_vector;
  TensorVector *out = butil::get_object<TensorVector>();
-  int batch_size = (*in)[0].shape[0];
+  int batch_size = 0;
+  if ((*in)[0].lod.size() == 1) {
+    batch_size = (*in)[0].lod[0].size() - 1;
+  }
+  else {
+    batch_size = (*in)[0].shape[0];
+  }
  // infer
  Timer timeline;
  double infer_time = 0.0;
  timeline.Start();
+  VLOG(2) << "batch size : " << batch_size;
  if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) {
    LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
    return -1;
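Note: the new branch above derives the batch size from the LoD when the input is a lod_tensor: lod[0] stores cumulative sequence offsets, so the number of sequences is its length minus one; for a dense tensor the batch size remains shape[0]. A sketch of the same rule in Python, for illustration only (the server code is C++):

def batch_size_of(shape, lod):
    if len(lod) == 1:      # lod_tensor: one level of sequence offsets
        return len(lod[0]) - 1
    return shape[0]        # plain tensor: first dim is the batch

# lod[0] = [0, 3, 5, 9] packs 3 sequences of lengths 3, 2 and 4
# into a tensor whose first dimension is 9:
assert batch_size_of([9, 128], [[0, 3, 5, 9]]) == 3
assert batch_size_of([4, 128], []) == 4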
......@@ -70,7 +77,7 @@ int GeneralInferOp::inference() {
  VLOG(2) << "start to call load general model_conf op";
  baidu::paddle_serving::predictor::Resource &resource =
      baidu::paddle_serving::predictor::Resource::instance();
  VLOG(2) << "get resource pointer done.";
  std::shared_ptr<PaddleGeneralModelConfig> model_config =
      resource.get_general_model_config();
......@@ -81,7 +88,7 @@ int GeneralInferOp::inference() {
    fetch_index[i] =
        model_config->_fetch_alias_name_to_index[req->fetch_var_names(i)];
  }
  // response inst with only fetch_var_names
  Response *res = mutable_data<Response>();
......@@ -94,7 +101,7 @@ int GeneralInferOp::inference() {
    // currently only response float tensor or lod_tensor
    tensor->set_elem_type(1);
    if (model_config->_is_lod_fetch[idx]) {
-      VLOG(2) << "out[" << idx << " is lod_tensor";
+      VLOG(2) << "out[" << idx << "] is lod_tensor";
      tensor->add_shape(-1);
    } else {
      VLOG(2) << "out[" << idx << "] is tensor";
......
......@@ -147,7 +147,7 @@ class Client(object):
        return result_map

-    def batch_predict(self, feed_batch=[], fetch=[]):
+    def batch_predict(self, feed_batch=[], fetch=[], debug=False):
        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
......@@ -181,13 +181,18 @@ class Client(object):
                                                         fetch_names)

        result_map_batch = []
-        for result in result_batch:
+        for result in result_batch[:-1]:
            result_map = {}
            for i, name in enumerate(fetch_names):
                result_map[name] = result[i]
            result_map_batch.append(result_map)
-        return result_map_batch
+        infer_time = result_batch[-1][0][0]
+        if debug:
+            return result_map_batch, infer_time
+        else:
+            return result_map_batch

    def release(self):
        self.client_handle_.destroy_predictor()
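Note: with the new debug flag, batch_predict strips the trailing timing row and, when asked, returns it alongside the per-sample results. A hypothetical usage sketch; the config path, endpoint, and feed/fetch names are placeholders that depend on the deployed model:

from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")  # placeholder path
client.connect(["127.0.0.1:9292"])                         # placeholder endpoint

feed_batch = [{"words": [1, 2, 3]}, {"words": [4, 5, 6]}]  # placeholder feed
fetch = ["prediction"]                                     # placeholder fetch var

# debug=True -> (per-sample result maps, server-side mean infer time in us)
result_map_batch, infer_time = client.batch_predict(feed_batch, fetch, debug=True)
print("mean infer time: {} us".format(infer_time))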
......@@ -17,7 +17,7 @@ from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
import tarfile
-import paddle_serving_server as paddle_serving_server
+import paddle_serving_server_gpu as paddle_serving_server
from version import serving_server_version
......