Commit 70466ca6 authored by M MRXLT

Merge remote-tracking branch 'upstream/develop' into 0.2.2-web-service

sync
@@ -89,50 +89,41 @@ int GeneralResponseOp::inference() {
   output->set_engine_name(pre_name);
   FetchInst *fetch_inst = output->add_insts();
-  std::map<std::string, int> fetch_index_map;
-  for (int i = 0; i < in->size(); ++i) {
-    VLOG(2) << "index " << i << " var " << in->at(i).name;
-    fetch_index_map.insert(std::pair<std::string, int>(in->at(i).name, i));
-  }
   for (auto &idx : fetch_index) {
     Tensor *tensor = fetch_inst->add_tensor_array();
     tensor->set_elem_type(1);
-    int true_idx = fetch_index_map[model_config->_fetch_name[idx]];
     if (model_config->_is_lod_fetch[idx]) {
       VLOG(2) << "out[" << idx << "] " << model_config->_fetch_name[idx]
               << " is lod_tensor";
-      for (int k = 0; k < in->at(true_idx).shape.size(); ++k) {
+      for (int k = 0; k < in->at(idx).shape.size(); ++k) {
         VLOG(2) << "shape[" << k << "]: " << in->at(idx).shape[k];
-        tensor->add_shape(in->at(true_idx).shape[k]);
+        tensor->add_shape(in->at(idx).shape[k]);
       }
     } else {
       VLOG(2) << "out[" << idx << "] " << model_config->_fetch_name[idx]
               << " is tensor";
-      for (int k = 0; k < in->at(true_idx).shape.size(); ++k) {
-        VLOG(2) << "shape[" << k << "]: " << in->at(true_idx).shape[k];
-        tensor->add_shape(in->at(true_idx).shape[k]);
+      for (int k = 0; k < in->at(idx).shape.size(); ++k) {
+        VLOG(2) << "shape[" << k << "]: " << in->at(idx).shape[k];
+        tensor->add_shape(in->at(idx).shape[k]);
       }
     }
   }
   int var_idx = 0;
   for (auto &idx : fetch_index) {
-    int true_idx = fetch_index_map[model_config->_fetch_name[idx]];
     int cap = 1;
-    for (int j = 0; j < in->at(true_idx).shape.size(); ++j) {
-      cap *= in->at(true_idx).shape[j];
+    for (int j = 0; j < in->at(idx).shape.size(); ++j) {
+      cap *= in->at(idx).shape[j];
     }
-    if (in->at(true_idx).dtype == paddle::PaddleDType::INT64) {
+    if (in->at(idx).dtype == paddle::PaddleDType::INT64) {
       VLOG(2) << "Prepare float var [" << model_config->_fetch_name[idx]
               << "].";
-      int64_t *data_ptr =
-          static_cast<int64_t *>(in->at(true_idx).data.data());
+      int64_t *data_ptr = static_cast<int64_t *>(in->at(idx).data.data());
       if (model_config->_is_lod_fetch[idx]) {
         FetchInst *fetch_p = output->mutable_insts(0);
-        for (int j = 0; j < in->at(true_idx).lod[0].size(); ++j) {
+        for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
           fetch_p->mutable_tensor_array(var_idx)->add_lod(
-              in->at(true_idx).lod[0][j]);
+              in->at(idx).lod[0][j]);
         }
         for (int j = 0; j < cap; ++j) {
           fetch_p->mutable_tensor_array(var_idx)->add_int64_data(data_ptr[j]);
@@ -145,15 +136,15 @@ int GeneralResponseOp::inference() {
       }
       VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
       var_idx++;
-    } else if (in->at(true_idx).dtype == paddle::PaddleDType::FLOAT32) {
+    } else if (in->at(idx).dtype == paddle::PaddleDType::FLOAT32) {
       VLOG(2) << "Prepare float var [" << model_config->_fetch_name[idx]
               << "].";
-      float *data_ptr = static_cast<float *>(in->at(true_idx).data.data());
+      float *data_ptr = static_cast<float *>(in->at(idx).data.data());
       if (model_config->_is_lod_fetch[idx]) {
         FetchInst *fetch_p = output->mutable_insts(0);
-        for (int j = 0; j < in->at(true_idx).lod[0].size(); ++j) {
+        for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
           fetch_p->mutable_tensor_array(var_idx)->add_lod(
-              in->at(true_idx).lod[0][j]);
+              in->at(idx).lod[0][j]);
         }
         for (int j = 0; j < cap; ++j) {
           fetch_p->mutable_tensor_array(var_idx)->add_float_data(data_ptr[j]);
...
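This hunk drops the name-to-index remapping in GeneralResponseOp: the old code built fetch_index_map from each output tensor's name and looked every configured fetch var up through true_idx, while the new code indexes the output vector directly with the config index idx. That only works if the engine returns fetch tensors in exactly the order the model config lists them, which the save_model change below guarantees. A minimal Python sketch of the two lookup strategies, with illustrative names that are not part of the PaddleServing API:

# Sketch only: "outputs" stands in for the op's input tensor vector,
# "fetch_names" for model_config->_fetch_name.
def pack_by_name(outputs, fetch_names):
    # Old approach: build a name -> position map, then resolve each
    # configured fetch var by name (robust to reordered outputs).
    index_map = {out["name"]: i for i, out in enumerate(outputs)}
    return [outputs[index_map[name]] for name in fetch_names]

def pack_by_position(outputs, fetch_names):
    # New approach: trust that outputs arrive in config order and index
    # directly -- cheaper, but correct only when the saved model and the
    # serving config agree on fetch-var ordering.
    return [outputs[i] for i, _ in enumerate(fetch_names)]

outputs = [{"name": "prob", "data": [0.9]}, {"name": "label", "data": [1]}]
assert pack_by_name(outputs, ["prob", "label"]) == \
       pack_by_position(outputs, ["prob", "label"])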
@@ -33,7 +33,12 @@ def save_model(server_model_folder,
     executor = Executor(place=CPUPlace())

     feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
-    target_vars = list(fetch_var_dict.values())
+    #target_vars = list(fetch_var_dict.values())
+    target_vars = []
+    target_var_names = []
+    for key in sorted(fetch_var_dict.keys()):
+        target_vars.append(fetch_var_dict[key])
+        target_var_names.append(key)

     save_inference_model(
         server_model_folder,
@@ -64,7 +69,7 @@ def save_model(server_model_folder,
         feed_var.shape.extend(tmp_shape)
         config.feed_var.extend([feed_var])

-    for key in fetch_var_dict:
+    for key in target_var_names:
         fetch_var = model_conf.FetchVar()
         fetch_var.alias_name = key
         fetch_var.name = fetch_var_dict[key].name
...
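The save_model change establishes the ordering invariant the C++ side now relies on: instead of iterating fetch_var_dict in arbitrary dict order (unordered on Python 2, where this code ran), it walks the keys in sorted order, so the variables passed to save_inference_model and the fetch_var entries written to the serving config always line up. A small sketch of the effect; the placeholder string values stand in for the Paddle Variable objects the real dict holds:

# Sketch only: real fetch_var_dict maps alias name -> fluid Variable.
fetch_var_dict = {"prob": "softmax_0.tmp_0", "label": "argmax_0.tmp_0"}

target_vars = []
target_var_names = []
for key in sorted(fetch_var_dict.keys()):   # deterministic order
    target_vars.append(fetch_var_dict[key])
    target_var_names.append(key)

# Both lists are aligned and reproducible across runs and interpreters:
assert target_var_names == ["label", "prob"]
assert target_vars == ["argmax_0.tmp_0", "softmax_0.tmp_0"]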