Commit c05d02c6 authored by guru4elephant

fix general reader infer batch size problem

Parent 5831f6f4
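The change below touches two ops. GeneralTextReaderOp fills the tensors of its output GeneralBlob but previously did not record the batch size on that blob, so the downstream GeneralInferOp's input_blob->GetBatchSize() could return a stale or unset value. The fix renames the reader's local alias from in to out (the pointer refers to the op's own output tensor_vector, so out is the honest name), calls SetBatchSize on each produced blob, and resets the blob's profiling counter (p_size = 0) before timestamps are appended. Below is a minimal sketch of the batch-size contract the fix relies on; this GeneralBlob is a hypothetical simplification for illustration, not the project's actual definition.

// Hypothetical, simplified GeneralBlob illustrating the contract the fix
// relies on: the op that produces a blob must record the batch size on it,
// because the next op in the chain reads it back via GetBatchSize().
#include <vector>

struct Tensor {};  // stand-in for the real tensor type
using TensorVector = std::vector<Tensor>;

struct GeneralBlob {
  TensorVector tensor_vector;
  int p_size = 0;       // number of profiling timestamps recorded so far
  int batch_size = -1;  // stays -1 until a producer calls SetBatchSize()

  void SetBatchSize(int bs) { batch_size = bs; }
  int GetBatchSize() const { return batch_size; }
};

// Producer side (the reader op): build the tensors, then stamp the batch size.
void produce(GeneralBlob *out_blob, int batch_size) {
  out_blob->tensor_vector.push_back(Tensor{});
  out_blob->SetBatchSize(batch_size);  // the call this commit adds
}

// Consumer side (the infer op): without the producer's SetBatchSize call,
// this returns the unset sentinel instead of the real batch size.
int consume(const GeneralBlob *in_blob) { return in_blob->GetBatchSize(); }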
@@ -36,8 +36,9 @@ using baidu::paddle_serving::predictor::InferManager;
 using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
 
 int GeneralInferOp::inference() {
   VLOG(2) << "Going to run inference";
   const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
+  VLOG(2) << "Get precedent op name: " << pre_name();
   GeneralBlob *output_blob = mutable_data<GeneralBlob>();
 
   if (!input_blob) {
@@ -48,6 +49,8 @@ int GeneralInferOp::inference() {
   const TensorVector *in = &input_blob->tensor_vector;
   TensorVector *out = &output_blob->tensor_vector;
   int batch_size = input_blob->GetBatchSize();
+  VLOG(2) << "input batch size: " << batch_size;
+  output_blob->SetBatchSize(batch_size);
 
   VLOG(2) << "infer batch size: " << batch_size;
@@ -45,7 +45,9 @@ int GeneralTextReaderOp::inference() {
   std::vector<int64_t> capacity;
 
   GeneralBlob *res = mutable_data<GeneralBlob>();
-  TensorVector *in = &res->tensor_vector;
+  TensorVector *out = &res->tensor_vector;
+  res->SetBatchSize(batch_size);
 
   if (!res) {
     LOG(ERROR) << "Failed get op tls reader object output";
@@ -103,23 +105,23 @@ int GeneralTextReaderOp::inference() {
       VLOG(2) << "var[" << i << "] is tensor, capacity: " << capacity[i];
     }
     lod_tensor.name = model_config->_feed_name[i];
-    in->push_back(lod_tensor);
+    out->push_back(lod_tensor);
   }
 
   for (int i = 0; i < var_num; ++i) {
-    if (in->at(i).lod.size() == 1) {
+    if (out->at(i).lod.size() == 1) {
       for (int j = 0; j < batch_size; ++j) {
         const Tensor &tensor = req->insts(j).tensor_array(i);
         int data_len = tensor.int_data_size();
-        int cur_len = in->at(i).lod[0].back();
-        in->at(i).lod[0].push_back(cur_len + data_len);
+        int cur_len = out->at(i).lod[0].back();
+        out->at(i).lod[0].push_back(cur_len + data_len);
       }
-      in->at(i).data.Resize(in->at(i).lod[0].back() * elem_size[i]);
-      in->at(i).shape = {in->at(i).lod[0].back(), 1};
+      out->at(i).data.Resize(out->at(i).lod[0].back() * elem_size[i]);
+      out->at(i).shape = {out->at(i).lod[0].back(), 1};
       VLOG(2) << "var[" << i
-              << "] is lod_tensor and len=" << in->at(i).lod[0].back();
+              << "] is lod_tensor and len=" << out->at(i).lod[0].back();
     } else {
-      in->at(i).data.Resize(batch_size * capacity[i] * elem_size[i]);
+      out->at(i).data.Resize(batch_size * capacity[i] * elem_size[i]);
       VLOG(2) << "var[" << i
               << "] is tensor and capacity=" << batch_size * capacity[i];
     }
@@ -127,7 +129,7 @@ int GeneralTextReaderOp::inference() {
 
   for (int i = 0; i < var_num; ++i) {
     if (elem_type[i] == 0) {
-      int64_t *dst_ptr = static_cast<int64_t *>(in->at(i).data.data());
+      int64_t *dst_ptr = static_cast<int64_t *>(out->at(i).data.data());
       int offset = 0;
       for (int j = 0; j < batch_size; ++j) {
         for (int k = 0;
@@ -136,14 +138,14 @@ int GeneralTextReaderOp::inference() {
           dst_ptr[offset + k] =
               req->insts(j).tensor_array(i).int_data(k);
         }
-        if (in->at(i).lod.size() == 1) {
-          offset = in->at(i).lod[0][j + 1];
+        if (out->at(i).lod.size() == 1) {
+          offset = out->at(i).lod[0][j + 1];
         } else {
           offset += capacity[i];
         }
       }
     } else {
-      float *dst_ptr = static_cast<float *>(in->at(i).data.data());
+      float *dst_ptr = static_cast<float *>(out->at(i).data.data());
       int offset = 0;
       for (int j = 0; j < batch_size; ++j) {
         for (int k = 0;
@@ -152,8 +154,8 @@ int GeneralTextReaderOp::inference() {
           dst_ptr[offset + k] =
               req->insts(j).tensor_array(i).int_data(k);
         }
-        if (in->at(i).lod.size() == 1) {
-          offset = in->at(i).lod[0][j + 1];
+        if (out->at(i).lod.size() == 1) {
+          offset = out->at(i).lod[0][j + 1];
         } else {
           offset += capacity[i];
         }
@@ -162,6 +164,7 @@ int GeneralTextReaderOp::inference() {
   }
 
   int64_t end = timeline.TimeStampUS();
+  res->p_size = 0;
   AddBlobInfo(res, start);
   AddBlobInfo(res, end);
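For context on the lod[0] arithmetic the rename runs through: level-0 LoD is a vector of cumulative row offsets, so appending an instance of data_len elements pushes back lod[0].back() + data_len, the flattened buffer is sized to the final offset, and the copy offset for the next instance is lod[0][j + 1]. A standalone sketch of that bookkeeping, using plain std::vector in place of the project's LoD and data types:

// Standalone sketch of the level-0 LoD offset bookkeeping used by the
// reader op: lod0 stores cumulative offsets, one entry per instance.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<std::size_t> lod0 = {0};            // cumulative offsets, starts at 0
  std::vector<std::size_t> seq_lens = {3, 5, 2};  // per-instance sequence lengths

  for (std::size_t len : seq_lens) {
    std::size_t cur_len = lod0.back();  // mirrors out->at(i).lod[0].back()
    lod0.push_back(cur_len + len);      // mirrors lod[0].push_back(cur_len + data_len)
  }

  // The total flattened length is the final offset; the op resizes the data
  // buffer to lod0.back() * elem_size[i] and shapes the tensor {lod0.back(), 1}.
  assert(lod0.back() == 10);

  // When copying instance j's payload, the write offset for the next
  // instance is lod0[j + 1], exactly as the loop in the diff computes it.
  std::size_t offset = 0;
  for (std::size_t j = 0; j < seq_lens.size(); ++j) {
    // ... copy seq_lens[j] elements to dst_ptr + offset here ...
    offset = lod0[j + 1];
  }
  return 0;
}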