提交 cf547e27 编写于 作者: X Xin Pan

fix program_desc feed/fetch names' order.

上级 08352fe5
......@@ -117,10 +117,16 @@ void ProgramDesc::InitFromProto() {
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
auto &global_block = Block(0);
// The order of feed_target_names must follow the index specified in `col`.
// since feed operator's order doesn't necessary follow 'col'.
std::vector<std::string> feed_target_names;
for (auto *op : global_block.AllOps()) {
if (op->Type() == kFeedOpType) {
feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]);
int col = boost::get<int>(op->GetAttr("col"));
if (col >= feed_target_names.size()) {
feed_target_names.resize(col + 1);
}
feed_target_names[col] = op->Output("Out")[0];
}
}
return feed_target_names;
......@@ -128,10 +134,16 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
auto &global_block = Block(0);
// The order of fetch_target_names must follow the index specified in `col`.
// since fetch operator's order doesn't necessary follow 'col'.
std::vector<std::string> fetch_target_names;
for (auto *op : global_block.AllOps()) {
if (op->Type() == kFetchOpType) {
fetch_target_names.push_back(op->Input("X")[0]);
int col = boost::get<int>(op->GetAttr("col"));
if (col >= fetch_target_names.size()) {
fetch_target_names.resize(col + 1);
}
fetch_target_names[col] = op->Input("X")[0];
}
}
return fetch_target_names;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册