提交 5adf118a 编写于 作者: X Xin Pan

polish

上级 c558f059
......@@ -67,14 +67,12 @@ void NativePaddlePredictor::PrepareFeedFetch() {
}
feeds_[idx] = op;
feed_names_[op->Output("Out")[0]] = idx;
LOG(ERROR) << "feed " << idx << " " << op->Output("Out")[0];
} else if (op->Type() == "fetch") {
int idx = boost::get<int>(op->GetAttr("col"));
if (fetchs_.size() <= idx) {
fetchs_.resize(idx + 1);
}
fetchs_[idx] = op;
LOG(ERROR) << "fetch " << idx << " " << op->Input("X")[0];
}
}
}
......@@ -216,8 +214,7 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
input.set_lod(lod);
int idx = -1;
if (config_.specify_input_name) {
idx =
boost::get<int>(feeds_[feed_names_[inputs[i].name]]->GetAttr("col"));
idx = feed_names_[inputs[i].name];
} else {
idx = boost::get<int>(feeds_[i]->GetAttr("col"));
}
......@@ -231,7 +228,6 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
VLOG(3) << "Predictor::get_fetch";
outputs->resize(fetchs_.size());
for (size_t i = 0; i < fetchs_.size(); ++i) {
std::string fetch_target_name = fetchs_[i]->Input("X")[0];
int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
PADDLE_ENFORCE(idx == i);
framework::LoDTensor &output =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册