提交 d00169c0 编写于 作者: M MRXLT

fix response op

上级 ef2ee57a
......@@ -168,7 +168,7 @@ int GeneralReaderOp::inference() {
int cur_len = out->at(i).lod[0].back();
VLOG(2) << "current len: " << cur_len;
int sample_len;
int sample_len = 0;
if (tensor.shape_size() == 1) {
sample_len = data_len;
} else {
......
......@@ -15,8 +15,10 @@
#include "core/general-server/op/general_response_op.h"
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <utility>
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
......@@ -86,37 +88,51 @@ int GeneralResponseOp::inference() {
// To get the order of model return values
output->set_engine_name(pre_name);
FetchInst *fetch_inst = output->add_insts();
std::map<std::string, int> fetch_index_map;
for (int i = 0; i < in->size(); ++i) {
VLOG(2) << "index " << i << " var " << in->at(i).name;
fetch_index_map.insert(std::pair<std::string, int>(in->at(i).name, i));
}
for (auto &idx : fetch_index) {
Tensor *tensor = fetch_inst->add_tensor_array();
tensor->set_elem_type(1);
int true_idx = fetch_index_map[model_config->_fetch_name[idx]];
if (model_config->_is_lod_fetch[idx]) {
VLOG(2) << "out[" << idx << "] is lod_tensor";
for (int k = 0; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "out[" << idx << "] " << model_config->_fetch_name[idx]
<< " is lod_tensor";
for (int k = 0; k < in->at(true_idx).shape.size(); ++k) {
VLOG(2) << "shape[" << k << "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
tensor->add_shape(in->at(true_idx).shape[k]);
}
} else {
VLOG(2) << "out[" << idx << "] is tensor";
for (int k = 0; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "shape[" << k << "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
VLOG(2) << "out[" << idx << "] " << model_config->_fetch_name[idx]
<< " is tensor";
for (int k = 0; k < in->at(true_idx).shape.size(); ++k) {
VLOG(2) << "shape[" << k << "]: " << in->at(true_idx).shape[k];
tensor->add_shape(in->at(true_idx).shape[k]);
}
}
}
int var_idx = 0;
for (auto &idx : fetch_index) {
int true_idx = fetch_index_map[model_config->_fetch_name[idx]];
int cap = 1;
for (int j = 0; j < in->at(idx).shape.size(); ++j) {
cap *= in->at(idx).shape[j];
for (int j = 0; j < in->at(true_idx).shape.size(); ++j) {
cap *= in->at(true_idx).shape[j];
}
if (in->at(idx).dtype == paddle::PaddleDType::INT64) {
int64_t *data_ptr = static_cast<int64_t *>(in->at(idx).data.data());
if (in->at(true_idx).dtype == paddle::PaddleDType::INT64) {
VLOG(2) << "Prepare float var [" << model_config->_fetch_name[idx]
<< "].";
int64_t *data_ptr =
static_cast<int64_t *>(in->at(true_idx).data.data());
if (model_config->_is_lod_fetch[idx]) {
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
for (int j = 0; j < in->at(true_idx).lod[0].size(); ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_lod(
in->at(idx).lod[0][j]);
in->at(true_idx).lod[0][j]);
}
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_int64_data(data_ptr[j]);
......@@ -127,14 +143,17 @@ int GeneralResponseOp::inference() {
fetch_p->mutable_tensor_array(var_idx)->add_int64_data(data_ptr[j]);
}
}
VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
var_idx++;
} else if (in->at(idx).dtype == paddle::PaddleDType::FLOAT32) {
float *data_ptr = static_cast<float *>(in->at(idx).data.data());
} else if (in->at(true_idx).dtype == paddle::PaddleDType::FLOAT32) {
VLOG(2) << "Prepare float var [" << model_config->_fetch_name[idx]
<< "].";
float *data_ptr = static_cast<float *>(in->at(true_idx).data.data());
if (model_config->_is_lod_fetch[idx]) {
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
for (int j = 0; j < in->at(true_idx).lod[0].size(); ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_lod(
in->at(idx).lod[0][j]);
in->at(true_idx).lod[0][j]);
}
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_float_data(data_ptr[j]);
......@@ -145,6 +164,7 @@ int GeneralResponseOp::inference() {
fetch_p->mutable_tensor_array(var_idx)->add_float_data(data_ptr[j]);
}
}
VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
var_idx++;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册