提交 9888ac01 编写于 作者: M MRXLT

fix predict return

上级 555916b8
...@@ -258,9 +258,10 @@ int PredictorClient::batch_predict( ...@@ -258,9 +258,10 @@ int PredictorClient::batch_predict(
ModelRes model; ModelRes model;
model.set_engine_name(output.engine_name()); model.set_engine_name(output.engine_name());
int idx = 0;
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name]; // int idx = _fetch_name_to_idx[name];
int idx = 0;
int shape_size = output.insts(0).tensor_array(idx).shape_size(); int shape_size = output.insts(0).tensor_array(idx).shape_size();
VLOG(2) << "fetch var " << name << " index " << idx << " shape size " VLOG(2) << "fetch var " << name << " index " << idx << " shape size "
<< shape_size; << shape_size;
...@@ -279,9 +280,9 @@ int PredictorClient::batch_predict( ...@@ -279,9 +280,9 @@ int PredictorClient::batch_predict(
idx += 1; idx += 1;
} }
idx = 0;
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name]; // int idx = _fetch_name_to_idx[name];
int idx = 0;
if (_fetch_name_to_type[name] == 0) { if (_fetch_name_to_type[name] == 0) {
VLOG(2) << "ferch var " << name << "type int"; VLOG(2) << "ferch var " << name << "type int";
model._int64_value_map[name].resize( model._int64_value_map[name].resize(
...@@ -536,9 +537,9 @@ int PredictorClient::numpy_predict( ...@@ -536,9 +537,9 @@ int PredictorClient::numpy_predict(
ModelRes model; ModelRes model;
model.set_engine_name(output.engine_name()); model.set_engine_name(output.engine_name());
int idx = 0;
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name]; // int idx = _fetch_name_to_idx[name];
int idx = 0;
int shape_size = output.insts(0).tensor_array(idx).shape_size(); int shape_size = output.insts(0).tensor_array(idx).shape_size();
VLOG(2) << "fetch var " << name << " index " << idx << " shape size " VLOG(2) << "fetch var " << name << " index " << idx << " shape size "
<< shape_size; << shape_size;
...@@ -557,9 +558,10 @@ int PredictorClient::numpy_predict( ...@@ -557,9 +558,10 @@ int PredictorClient::numpy_predict(
idx += 1; idx += 1;
} }
idx = 0;
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name]; // int idx = _fetch_name_to_idx[name];
int idx = 0;
if (_fetch_name_to_type[name] == 0) { if (_fetch_name_to_type[name] == 0) {
VLOG(2) << "ferch var " << name << "type int"; VLOG(2) << "ferch var " << name << "type int";
model._int64_value_map[name].resize( model._int64_value_map[name].resize(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册