未验证 提交 fa842ae6 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #425 from MRXLT/flexible-shape

fix fetch var
...@@ -309,7 +309,7 @@ int PredictorClient::batch_predict( ...@@ -309,7 +309,7 @@ int PredictorClient::batch_predict(
tensor_vec.push_back(inst->add_tensor_array()); tensor_vec.push_back(inst->add_tensor_array());
} }
VLOG(2) << "batch [" << bi << "] int_feed_name and float_feed_name" VLOG(2) << "batch [" << bi << "] int_feed_name and float_feed_name "
<< "prepared"; << "prepared";
int vec_idx = 0; int vec_idx = 0;
for (auto &name : float_feed_name) { for (auto &name : float_feed_name) {
...@@ -375,9 +375,11 @@ int PredictorClient::batch_predict( ...@@ -375,9 +375,11 @@ int PredictorClient::batch_predict(
predict_res_batch._int64_map[name].resize(batch_size); predict_res_batch._int64_map[name].resize(batch_size);
predict_res_batch._float_map[name].resize(batch_size); predict_res_batch._float_map[name].resize(batch_size);
} }
VLOG(2) << "response batch size " << res.insts_size();
VLOG(2) << "response var nmae " << res.insts(0).tensor_array_size();
for (int bi = 0; bi < batch_size; bi++) { for (int bi = 0; bi < batch_size; bi++) {
int idx = 0;
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
int idx = _fetch_name_to_idx[name];
int len = res.insts(bi).tensor_array(idx).data_size(); int len = res.insts(bi).tensor_array(idx).data_size();
if (_fetch_name_to_type[name] == 0) { if (_fetch_name_to_type[name] == 0) {
int len = res.insts(bi).tensor_array(idx).int64_data_size(); int len = res.insts(bi).tensor_array(idx).int64_data_size();
...@@ -401,6 +403,7 @@ int PredictorClient::batch_predict( ...@@ -401,6 +403,7 @@ int PredictorClient::batch_predict(
res.insts(bi).tensor_array(idx).float_data(i); res.insts(bi).tensor_array(idx).float_data(i);
} }
} }
idx += 1;
} }
} }
postprocess_end = timeline.TimeStampUS(); postprocess_end = timeline.TimeStampUS();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册