From 9c9e28b57ba96b60fe6289678710e36ff87cece4 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Sun, 15 Jul 2018 22:07:57 +0800 Subject: [PATCH] fix program to graph --- .../details/multi_devices_graph_builder.cc | 5 ++++- paddle/fluid/framework/ir/graph.cc | 16 ++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 035fb629a8..1e7ec95342 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -210,7 +210,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( size_t cur_device_id = 0; bool is_forwarding = true; - // TODO(panyx0718): FIXME: nodes should be sorted by "program" order. + // NOTE: Currently, passes before SSAGraphBuilder cannot reorder + // forward, backward nodes. E.g. you can't append an forward node + // at the end of the node list. + // TODO(panyx0718): FIXME: Needs to sort by forward->backward order. for (auto &node : nodes) { if (node->NodeType() != ir::Node::Type::kOperation) continue; if (boost::get( diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 1f6937658f..f8381af985 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -19,31 +19,43 @@ limitations under the License. */ namespace paddle { namespace framework { +// NOTE(paddle-dev): This graph contains circle. Graph::Graph(const ProgramDesc &program) : program_(program) { std::unordered_map all_vars; for (auto *var : program.Block(0).AllVars()) { all_vars.emplace(var->Name(), var); } + std::map var_nodes; for (auto *op : program.Block(0).AllOps()) { ir::Node *node = CreateOpNode(op); for (auto &each_var_name : op->InputArgumentNames()) { ir::Node *var = nullptr; - if (all_vars.count(each_var_name) != 0) { + if (var_nodes.find(each_var_name) != var_nodes.end()) { + var = var_nodes.at(each_var_name); + } else if (all_vars.count(each_var_name) != 0) { var = CreateVarNode(all_vars.at(each_var_name)); + var_nodes[each_var_name] = var; } else { // TODO(paddle-dev): Seems some assumption doesn't hold? LOG(ERROR) << op->Type() << " input var not in all_var list: " << each_var_name; var = CreateEmptyNode(each_var_name); + var_nodes[each_var_name] = var; } node->inputs.push_back(var); var->outputs.push_back(node); } for (auto &each_var_name : op->OutputArgumentNames()) { - ir::Node *var = CreateVarNode(all_vars.at(each_var_name)); + ir::Node *var = nullptr; + if (var_nodes.find(each_var_name) != var_nodes.end()) { + var = var_nodes.at(each_var_name); + } else { + var = CreateVarNode(all_vars.at(each_var_name)); + var_nodes[each_var_name] = var; + } node->outputs.push_back(var); var->inputs.push_back(node); } -- GitLab