Commit 43c55732 authored by MRXLT

fix general model op batch predict bug

Parent 9c3c3eb9
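The actual bug fix is a single loop bound in the response-packing stage: the capacity loop over `out->at(i).shape` now starts at index 1, so the batch dimension is no longer multiplied into the per-sample capacity `cap`. All the other changes in the diff are clang-format cleanups (pointer and reference spacing, line rewrapping, include ordering, and `reinterpret_cast` in place of C-style casts). Below is a minimal standalone sketch (simplified shapes and data, not the serving code itself) of why `cap` must skip dim 0 when batch_size > 1:

#include <cassert>
#include <vector>

int main() {
  // A dense fetch tensor with shape {batch, d1, d2}: 2 samples of 3 x 4 floats.
  std::vector<int> shape = {2, 3, 4};
  std::vector<float> data(2 * 3 * 4, 0.f);

  // Old capacity: multiplied in the batch dim as well, giving cap == 24.
  int cap_buggy = 1;
  for (size_t j = 0; j < shape.size(); ++j) cap_buggy *= shape[j];

  // Fixed capacity: per-sample element count only, giving cap == 12.
  int cap_fixed = 1;
  for (size_t j = 1; j < shape.size(); ++j) cap_fixed *= shape[j];

  // The dense packing loop copies [j * cap, (j + 1) * cap) for sample j.
  // With the buggy cap, sample 1 would start at index 24 == data.size(),
  // i.e. read past the end of the buffer; the fixed cap stays in bounds.
  const int batch = shape[0];
  assert(batch * cap_buggy > static_cast<int>(data.size()));   // 48 > 24
  assert(batch * cap_fixed == static_cast<int>(data.size()));  // 24 == 24
  return 0;
}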
@@ -14,8 +14,8 @@
 #include "examples/demo-serving/op/general_model_op.h"
 #include <algorithm>
-#include <sstream>
 #include <iostream>
+#include <sstream>
 #include "core/predictor/framework/infer.h"
 #include "core/predictor/framework/memory.h"
@@ -69,9 +69,7 @@ int GeneralModelOp::inference() {
     } else {
       lod_tensor.shape.push_back(batch_size);
       capacity[i] = 1;
-      for (int k = 0;
-           k < req->insts(0).tensor_array(i).shape_size();
-           ++k) {
+      for (int k = 0; k < req->insts(0).tensor_array(i).shape_size(); ++k) {
         int dim = req->insts(0).tensor_array(i).shape(k);
         VLOG(3) << "shape for var[" << i << "]: " << dim;
         capacity[i] *= dim;
@@ -90,10 +88,9 @@ int GeneralModelOp::inference() {
   for (int i = 0; i < var_num; ++i) {
     if (in->at(i).lod.size() == 1) {
       for (int j = 0; j < batch_size; ++j) {
-        const Tensor & tensor = req->insts(j).tensor_array(i);
+        const Tensor &tensor = req->insts(j).tensor_array(i);
         int data_len = tensor.data_size();
-        VLOG(3) << "tensor size for var[" << i << "]: "
-                << tensor.data_size();
+        VLOG(3) << "tensor size for var[" << i << "]: " << tensor.data_size();
         int cur_len = in->at(i).lod[0].back();
         VLOG(3) << "current len: " << cur_len;
         in->at(i).lod[0].push_back(cur_len + data_len);
@@ -101,18 +98,18 @@ int GeneralModelOp::inference() {
       }
       in->at(i).data.Resize(in->at(i).lod[0].back() * elem_size[i]);
       in->at(i).shape = {in->at(i).lod[0].back(), 1};
-      VLOG(3) << "var[" << i << "] is lod_tensor and len="
-              << in->at(i).lod[0].back();
+      VLOG(3) << "var[" << i
+              << "] is lod_tensor and len=" << in->at(i).lod[0].back();
     } else {
       in->at(i).data.Resize(batch_size * capacity[i] * elem_size[i]);
-      VLOG(3) << "var[" << i << "] is tensor and capacity="
-              << batch_size * capacity[i];
+      VLOG(3) << "var[" << i
+              << "] is tensor and capacity=" << batch_size * capacity[i];
     }
   }

   for (int i = 0; i < var_num; ++i) {
     if (elem_type[i] == 0) {
-      int64_t * dst_ptr = static_cast<int64_t *>(in->at(i).data.data());
+      int64_t *dst_ptr = static_cast<int64_t *>(in->at(i).data.data());
       int offset = 0;
       for (int j = 0; j < batch_size; ++j) {
         for (int k = 0; k < req->insts(j).tensor_array(i).data_size(); ++k) {
@@ -126,7 +123,7 @@ int GeneralModelOp::inference() {
         }
       }
     } else {
-      float * dst_ptr = static_cast<float *>(in->at(i).data.data());
+      float *dst_ptr = static_cast<float *>(in->at(i).data.data());
       int offset = 0;
       for (int j = 0; j < batch_size; ++j) {
         for (int k = 0; k < req->insts(j).tensor_array(i).data_size(); ++k) {
@@ -151,17 +148,16 @@ int GeneralModelOp::inference() {
   if (predictor::InferManager::instance().infer(
           GENERAL_MODEL_NAME, in, out, batch_size)) {
-    LOG(ERROR) << "Failed do infer in fluid model: "
-               << GENERAL_MODEL_NAME;
+    LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
     return -1;
   }

-  Response * res = mutable_data<Response>();
+  Response *res = mutable_data<Response>();

   for (int i = 0; i < batch_size; ++i) {
-    FetchInst * fetch_inst = res->add_insts();
+    FetchInst *fetch_inst = res->add_insts();
     for (int j = 0; j < out->size(); ++j) {
-      Tensor * tensor = fetch_inst->add_tensor_array();
+      Tensor *tensor = fetch_inst->add_tensor_array();
       tensor->set_elem_type(1);
       if (out->at(j).lod.size() == 1) {
         tensor->add_shape(-1);
@@ -174,30 +170,29 @@ int GeneralModelOp::inference() {
   }

   for (int i = 0; i < out->size(); ++i) {
-    float * data_ptr = static_cast<float *>(out->at(i).data.data());
+    float *data_ptr = static_cast<float *>(out->at(i).data.data());
     int cap = 1;
-    for (int j = 0; j < out->at(i).shape.size(); ++j) {
+    for (int j = 1; j < out->at(i).shape.size(); ++j) {
       cap *= out->at(i).shape[j];
     }
     if (out->at(i).lod.size() == 1) {
       for (int j = 0; j < batch_size; ++j) {
-        for (int k = out->at(i).lod[0][j];
-             k < out->at(i).lod[0][j + 1];
+        for (int k = out->at(i).lod[0][j]; k < out->at(i).lod[0][j + 1];
              k++) {
           res->mutable_insts(j)->mutable_tensor_array(i)->add_data(
-              (char *)(&(data_ptr[k])), sizeof(float));
+              reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
         }
       }
     } else {
       for (int j = 0; j < batch_size; ++j) {
         for (int k = j * cap; k < (j + 1) * cap; ++k) {
           res->mutable_insts(j)->mutable_tensor_array(i)->add_data(
-              (char *)(&(data_ptr[k])), sizeof(float));
+              reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
         }
       }
     }
   }

   for (size_t i = 0; i < in->size(); ++i) {
     (*in)[i].shape.clear();
   }
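For context on the surrounding code, the packing stage has two paths: variable-length (lod) outputs slice the flat buffer by the lod offset table, while fixed-shape outputs slice it by the per-sample capacity that this commit corrects. A hedged sketch of the two schemes, with simplified types in place of the protobuf Tensor API:

#include <cstdio>
#include <vector>

// Sketch: how a flat output buffer is split per sample (simplified; the real
// op appends raw bytes to a protobuf Tensor instead of printing).
void pack(const std::vector<float> &data, const std::vector<int> &lod,
          int batch_size, int cap) {
  if (!lod.empty()) {
    // lod path: sample j owns the half-open range [lod[j], lod[j + 1]).
    for (int j = 0; j < batch_size; ++j)
      for (int k = lod[j]; k < lod[j + 1]; ++k)
        std::printf("sample %d: %g\n", j, data[k]);
  } else {
    // dense path: sample j owns [j * cap, (j + 1) * cap).
    for (int j = 0; j < batch_size; ++j)
      for (int k = j * cap; k < (j + 1) * cap; ++k)
        std::printf("sample %d: %g\n", j, data[k]);
  }
}

int main() {
  pack({1, 2, 3, 4, 5, 6}, {0, 2, 6}, 2, 0);  // variable-length: 2 then 4 elems
  pack({1, 2, 3, 4, 5, 6}, {}, 2, 3);         // fixed-shape: 3 elems per sample
  return 0;
}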