diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 035fb629a89cb3117dd5012dcca93e4f3486e3e8..1e7ec95342e406523b2444918220780b8bbb62be 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -210,7 +210,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( size_t cur_device_id = 0; bool is_forwarding = true; - // TODO(panyx0718): FIXME: nodes should be sorted by "program" order. + // NOTE: Currently, passes before SSAGraphBuilder cannot reorder + // forward, backward nodes. E.g. you can't append an forward node + // at the end of the node list. + // TODO(panyx0718): FIXME: Needs to sort by forward->backward order. for (auto &node : nodes) { if (node->NodeType() != ir::Node::Type::kOperation) continue; if (boost::get( diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 1f6937658f99f62be80e8ea671eba4d69d8da189..f8381af985a129d153a60daefabe2e4d3fafa375 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -19,31 +19,43 @@ limitations under the License. */ namespace paddle { namespace framework { +// NOTE(paddle-dev): This graph contains circle. Graph::Graph(const ProgramDesc &program) : program_(program) { std::unordered_map all_vars; for (auto *var : program.Block(0).AllVars()) { all_vars.emplace(var->Name(), var); } + std::map var_nodes; for (auto *op : program.Block(0).AllOps()) { ir::Node *node = CreateOpNode(op); for (auto &each_var_name : op->InputArgumentNames()) { ir::Node *var = nullptr; - if (all_vars.count(each_var_name) != 0) { + if (var_nodes.find(each_var_name) != var_nodes.end()) { + var = var_nodes.at(each_var_name); + } else if (all_vars.count(each_var_name) != 0) { var = CreateVarNode(all_vars.at(each_var_name)); + var_nodes[each_var_name] = var; } else { // TODO(paddle-dev): Seems some assumption doesn't hold? LOG(ERROR) << op->Type() << " input var not in all_var list: " << each_var_name; var = CreateEmptyNode(each_var_name); + var_nodes[each_var_name] = var; } node->inputs.push_back(var); var->outputs.push_back(node); } for (auto &each_var_name : op->OutputArgumentNames()) { - ir::Node *var = CreateVarNode(all_vars.at(each_var_name)); + ir::Node *var = nullptr; + if (var_nodes.find(each_var_name) != var_nodes.end()) { + var = var_nodes.at(each_var_name); + } else { + var = CreateVarNode(all_vars.at(each_var_name)); + var_nodes[each_var_name] = var; + } node->outputs.push_back(var); var->inputs.push_back(node); }