diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index df55b3d05402f1aeecfd8d4218a637a81d58ed87..620d202d33014ebc6142d7e0065e569cb0613e4d 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -5,7 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
 
-cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph)
+cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph graph_helper)
 cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
 cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index dc9183e96aa6ac898e24e162177a1865a097ab1b..4cc6e5727b2bc1b3e1bba03b0a1de40125f1ef54 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -25,6 +25,7 @@
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/scope.h"
@@ -186,9 +187,55 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
   return dev_id;
 }
 
+// Topology-sort the graph nodes from inputs to outputs.
+// Since SSAGraphBuilder depends on forward/backward nodes to assign devices
+// to parameters/gradients before optimizer ops, a topological sort alone is
+// insufficient (some optimizer ops might not depend on any backward node),
+// so we manually move all optimizer nodes after the last backward node.
+std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
+  std::vector<ir::Node *> ret = ir::TopologySort(graph);
+  size_t last_backward = 0;
+  std::vector<ir::Node *> optimize_ops;
+  std::vector<ir::Node *> sorted_ret;
+  for (size_t i = 0; i < ret.size(); ++i) {
+    if (boost::get<int>(
+            ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+        static_cast<int>(OpRole::kBackward)) {
+      sorted_ret.push_back(ret[i]);
+      last_backward = sorted_ret.size();
+    } else if (boost::get<int>(ret[i]->Op()->GetAttr(
+                   OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+               static_cast<int>(OpRole::kOptimize)) {
+      optimize_ops.push_back(ret[i]);
+    } else {
+      sorted_ret.push_back(ret[i]);
+    }
+  }
+
+  sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
+                    optimize_ops.end());
+
+  for (ir::Node *n : sorted_ret) {
+    n->inputs.erase(std::remove_if(n->inputs.begin(), n->inputs.end(),
+                                   [n](ir::Node *t) {
+                                     return t->Name() ==
+                                            ir::Node::kControlDepVarName;
+                                   }),
+                    n->inputs.end());
+    n->outputs.erase(std::remove_if(n->outputs.begin(), n->outputs.end(),
+                                    [n](ir::Node *t) {
+                                      return t->Name() ==
+                                             ir::Node::kControlDepVarName;
+                                    }),
+                     n->outputs.end());
+  }
+  return sorted_ret;
+}
+
 std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
     std::unique_ptr<Graph> graph) const {
   // Rebuild the graph structure.
+  std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
   auto nodes = std::move(graph->nodes);
   graph->nodes.clear();
 
@@ -217,12 +264,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
   size_t cur_device_id = 0;
   bool is_forwarding = true;
 
-  // NOTE: Currently, passes before SSAGraphBuilder cannot reorder
-  // forward, backward nodes. E.g. you can't append an forward node
-  // at the end of the node list.
-  // TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
-  for (ir::Node *node : TopologySortOperationFromInToOut(nodes)) {
-    VLOG(3) << "apply node: " << node->Name() << reinterpret_cast<void *>(node);
+  for (ir::Node *node : sorted_ops) {
     if (boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
@@ -240,7 +282,6 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
       // It also assumes backward op will always follow the forward op in
       // the block.
       is_forwarding = false;
-      LOG(ERROR) << "forward flipping!!!!!!!";
     } else {
       int op_dev_id = GetOpDeviceID(node);
       if (op_dev_id != -1) {  // This op only runs on one specific device.
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index ee0604383ec9df826fa2abaef1f643ba0da6a096..744696ebb0cdcdc5f7fa9a94bcbf7fa839157b67 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -1,5 +1,5 @@
 cc_library(node SRCS node.cc DEPS proto_desc)
 cc_library(graph SRCS graph.cc DEPS node)
+cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
 cc_library(pass SRCS pass.cc DEPS graph node)
-
 cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry)
diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index f297461ab27df653a529b2d08320a8bf95daac9c..46640fedcce16079abdebe61bc3c8fb87f8822eb 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -22,6 +22,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+/*
 namespace {
 void SortHelper(
     const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
@@ -39,7 +40,21 @@ void SortHelper(
             << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
   ret->push_back(node);
 }
+
+std::vector<ir::Node *> TopologySort(
+    const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
+  std::unordered_set<ir::Node *> visited;
+  std::vector<ir::Node *> ret;
+
+  for (auto adj : adj_list) {
+    if (visited.find(adj.first) == visited.end()) {
+      SortHelper(adj_list, adj.first, &visited, &ret);
+    }
+  }
+  return ret;
+}
 }  // namespace
+*/
 
 Graph::Graph(const ProgramDesc &program) : program_(program) {
   VLOG(3) << "block in program:" << program_.Size();
@@ -48,20 +63,9 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
     all_vars.emplace(var->Name(), var);
   }
 
-  ir::Node *last_backward = nullptr;
-  std::vector<ir::Node *> optimize_ops;
   std::map<std::string, std::vector<ir::Node *>> var_nodes;
   for (auto *op : program.Block(0).AllOps()) {
     ir::Node *node = CreateOpNode(op);
-    if (boost::get<int>(
-            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
-        static_cast<int>(OpRole::kBackward)) {
-      last_backward = node;
-    } else if (boost::get<int>(
-                   op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
-               static_cast<int>(OpRole::kOptimize)) {
-      optimize_ops.push_back(node);
-    }
 
     for (auto &each_var_name : op->InputArgumentNames()) {
       ir::Node *var = nullptr;
@@ -106,7 +110,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
         // Read Write is the same op.
         continue;
        }
-       ir::Node *dep_var = CreateEmptyNode("dummy", ir::Node::Type::kVariable);
+       ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
+                                           ir::Node::Type::kVariable);
        read_op->outputs.push_back(dep_var);
        dep_var->inputs.push_back(read_op);
        write_op->inputs.push_back(dep_var);
@@ -114,62 +119,121 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
       }
     }
   }
+}
 
-  if (last_backward) {
-    for (ir::Node *opt_node : optimize_ops) {
-      ir::Node *dep_var = CreateEmptyNode("dummy", ir::Node::Type::kVariable);
-      last_backward->outputs.push_back(dep_var);
-      dep_var->inputs.push_back(last_backward);
-      opt_node->inputs.push_back(dep_var);
-      dep_var->outputs.push_back(opt_node);
-      VLOG(3) << "appending connect: " << last_backward->Name()
-              << reinterpret_cast<void *>(last_backward) << "->"
-              << opt_node->Name() << reinterpret_cast<void *>(opt_node);
+/*
+bool HasCircleHelper(ir::Node* node,
+                     const std::map<ir::Node *, std::unordered_set<ir::Node *>>
+&adj_list,
+                     std::unordered_set<ir::Node *>* visited,
+                     std::unordered_set<ir::Node *>* in_trace) {
+  if (visited->find(node) == visited->end()) {
+    visited->insert(node);
+    in_trace->insert(node);
+
+    for (ir::Node *in : adj_list.at(node)) {
+      if (visited->find(in) == visited->end() &&
+          HasCircleHelper(in, adj_list, visited, in_trace)) {
+        return true;
+      } else if (in_trace->find(in) != in_trace->end()) {
+        return true;
+      }
     }
   }
+  in_trace->erase(node);
+  return false;
 }
 
-std::vector<ir::Node *> TopologySortOperationFromInToOut(
-    const std::vector<std::unique_ptr<ir::Node>> &nodes) {
+bool HasCircle(const std::map<ir::Node *, std::unordered_set<ir::Node *>>
+&adj_list) {
+  std::unordered_set<ir::Node *> visited;
+  std::unordered_set<ir::Node *> in_trace;
+  for (auto& adj : adj_list) {
+    if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
+    const std::vector<ir::Node *> &nodes) {
   std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
-  std::unordered_set<ir::Node *> visited;
-  std::vector<ir::Node *> ret;
 
   for (auto &n : nodes) {
     if (n->NodeType() != ir::Node::Type::kOperation) continue;
-    if (adj_list.find(n.get()) == adj_list.end()) {
-      adj_list[n.get()] = std::unordered_set<ir::Node *>();
+    if (adj_list.find(n) == adj_list.end()) {
+      adj_list[n] = std::unordered_set<ir::Node *>();
     }
     for (auto &var : n->inputs) {
       for (auto &adj_n : var->inputs) {
         PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
-        adj_list[n.get()].insert(adj_n);
+        adj_list[n].insert(adj_n);
         LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
-                   << " -> " << n->Name() << reinterpret_cast<void *>(n.get())
+                   << " -> " << n->Name() << reinterpret_cast<void *>(n)
                    << " via " << var->Name() << reinterpret_cast<void *>(var);
       }
     }
   }
+  return adj_list;
+}
 
-  for (auto adj : adj_list) {
-    if (visited.find(adj.first) == visited.end()) {
-      SortHelper(adj_list, adj.first, &visited, &ret);
+std::vector<ir::Node *> TopologySortOperationFromInToOut(
+    const std::vector<std::unique_ptr<ir::Node>> &nodes) {
+  std::vector<ir::Node *> tmp;
+  for (auto& n : nodes) {
+    tmp.push_back(n.get());
+  }
+  std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
+BuildAdjList(tmp);
+
+  PADDLE_ENFORCE(!HasCircle(adj_list));
+  std::vector<ir::Node *> ret = TopologySort(adj_list);
+
+  ir::Node *last_backward = nullptr;
+  std::vector<ir::Node *> optimize_ops;
+  for (ir::Node* n : ret) {
+    if (boost::get<int>(
+            n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+        static_cast<int>(OpRole::kBackward)) {
+      last_backward = n;
+    } else if (boost::get<int>(
+                   n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+               static_cast<int>(OpRole::kOptimize)) {
+      optimize_ops.push_back(n);
     }
   }
 
+  if (last_backward) {
+    for (ir::Node *opt_node : optimize_ops) {
+      ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
+                                          ir::Node::Type::kVariable);
+      last_backward->outputs.push_back(dep_var);
+      dep_var->inputs.push_back(last_backward);
+      opt_node->inputs.push_back(dep_var);
+      dep_var->outputs.push_back(opt_node);
+      VLOG(3) << "appending connect: " << last_backward->Name()
+              << reinterpret_cast<void *>(last_backward) << "->"
+              << opt_node->Name() << reinterpret_cast<void *>(opt_node);
+    }
+  }
+
+  PADDLE_ENFORCE(!HasCircle(adj_list));
   for (ir::Node *n : ret) {
     std::unordered_set<ir::Node *> dummy;
     n->inputs.erase(
         std::remove_if(n->inputs.begin(), n->inputs.end(),
-                       [n](ir::Node *t) { return t->Name() == "dummy"; }),
+                       [n](ir::Node *t) {
+                         return t->Name() == ir::Node::kControlDepVarName; }),
         n->inputs.end());
     n->outputs.erase(
         std::remove_if(n->outputs.begin(), n->outputs.end(),
-                       [n](ir::Node *t) { return t->Name() == "dummy"; }),
+                       [n](ir::Node *t) {
+                         return t->Name() == ir::Node::kControlDepVarName; }),
         n->outputs.end());
   }
   return ret;
-}
+}*/
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index 0242edecf4525fad45d9203740997035587e7130..b4ac135b029005b723abca2cb9b9a9aa175eda40 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -78,8 +78,5 @@ class Graph {
   std::map<std::string, std::function<void(void)>> attr_dels_;
 };
 
-std::vector<ir::Node *> TopologySortOperationFromInToOut(
-    const std::vector<std::unique_ptr<ir::Node>>& nodes);
-
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ecd90f4f3ec699298793a58ec3ce3d2d4a41ab03
--- /dev/null
+++ b/paddle/fluid/framework/ir/graph_helper.cc
@@ -0,0 +1,116 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <unordered_set>
+
+#include "paddle/fluid/framework/ir/graph_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace {
+void SortHelper(
+    const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
+    ir::Node *node, std::unordered_set<ir::Node *> *visited,
+    std::vector<ir::Node *> *ret) {
+  visited->insert(node);
+
+  for (auto adj : adj_list.at(node)) {
+    if (visited->find(adj) == visited->end()) {
+      SortHelper(adj_list, adj, visited, ret);
+    }
+  }
+
+  LOG(ERROR) << "topology sort insert: " << node->Name()
+             << reinterpret_cast<void *>(node) << " input "
+             << node->inputs.size();
+  ret->push_back(node);
+}
+
+bool HasCircleHelper(
+    ir::Node *node,
+    const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
+    std::unordered_set<ir::Node *> *visited,
+    std::unordered_set<ir::Node *> *in_trace) {
+  if (visited->find(node) == visited->end()) {
+    visited->insert(node);
+    in_trace->insert(node);
+
+    for (ir::Node *in : adj_list.at(node)) {
+      if (visited->find(in) == visited->end() &&
+          HasCircleHelper(in, adj_list, visited, in_trace)) {
+        return true;
+      } else if (in_trace->find(in) != in_trace->end()) {
+        return true;
+      }
+    }
+  }
+  in_trace->erase(node);
+  return false;
+}
+}  // namespace
+
+bool HasCircle(const Graph &graph) {
+  std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
+      BuildAdjList(graph);
+
+  std::unordered_set<ir::Node *> visited;
+  std::unordered_set<ir::Node *> in_trace;
+  for (auto &adj : adj_list) {
+    if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+std::vector<ir::Node *> TopologySort(const Graph &graph) {
+  std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
+      BuildAdjList(graph);
+  std::unordered_set<ir::Node *> visited;
+  std::vector<ir::Node *> ret;
+  for (auto adj : adj_list) {
+    if (visited.find(adj.first) == visited.end()) {
+      SortHelper(adj_list, adj.first, &visited, &ret);
+    }
+  }
+  return ret;
+}
+
+std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
+    const Graph &graph) {
+  std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
+
+  for (auto &n : graph.nodes) {
+    if (n->NodeType() != ir::Node::Type::kOperation) continue;
+    if (adj_list.find(n.get()) == adj_list.end()) {
+      adj_list[n.get()] = std::unordered_set<ir::Node *>();
+    }
+    for (auto &var : n->inputs) {
+      for (auto &adj_n : var->inputs) {
+        PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
+        adj_list[n.get()].insert(adj_n);
+        LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
+                   << " -> " << n->Name() << reinterpret_cast<void *>(n.get())
+                   << " via " << var->Name() << reinterpret_cast<void *>(var);
+      }
+    }
+  }
+  return adj_list;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h
new file mode 100644
index 0000000000000000000000000000000000000000..b8714eb5be03143657c22903b0283af7d76f83bd
--- /dev/null
+++ b/paddle/fluid/framework/ir/graph_helper.h
@@ -0,0 +1,35 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <map>
+#include <unordered_set>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/node.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+bool HasCircle(const Graph &graph);
+
+std::vector<ir::Node *> TopologySort(const Graph &graph);
+
+std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
+    const Graph &graph);
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc
index 86376e7e8bc8bee2ddbc18f7f24bcdd849a06cbf..aca77da8d674f29b89c023717cdcd061232d023a 100644
--- a/paddle/fluid/framework/ir/node.cc
+++ b/paddle/fluid/framework/ir/node.cc
@@ -15,5 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/node.h"
 
 namespace paddle {
-namespace framework {}  // namespace framework
+namespace framework {
+namespace ir {
+const char Node::kControlDepVarName[] = "__control_var";
+}  // namespace ir
+}  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index 97b64a6017ef08ffc73ae22beb18321934506078..b3138fccee86fb274abe72007961fc1c982b1e96 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -27,6 +27,8 @@ namespace ir {
 
 class Node {
  public:
  enum class Type { kOperation, kVariable };
+  static const char kControlDepVarName[];
+
  explicit Node(const std::string& name, Type type)
      : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
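
For context, a minimal usage sketch of the ir::HasCircle / ir::TopologySort helpers introduced above. This is a standalone illustration, not part of the patch: the SortedOps wrapper, its include list, and the ProgramDesc setup are assumptions based only on what this diff shows.

  #include <vector>

  #include "paddle/fluid/framework/ir/graph.h"
  #include "paddle/fluid/framework/ir/graph_helper.h"
  #include "paddle/fluid/framework/program_desc.h"

  // Hypothetical helper (illustrative name): build a Graph from a ProgramDesc
  // and return its operator nodes in inputs-to-outputs order, refusing to sort
  // when the dependency graph contains a cycle.
  std::vector<paddle::framework::ir::Node *> SortedOps(
      const paddle::framework::ProgramDesc &program) {
    namespace fw = paddle::framework;
    // The Graph constructor creates op/var nodes and inserts control-dependency
    // variables (named ir::Node::kControlDepVarName) for read/write conflicts.
    fw::Graph graph(program);
    // A cyclic dependency graph cannot be topologically sorted.
    if (fw::ir::HasCircle(graph)) {
      return {};
    }
    // Only kOperation nodes are returned; callers such as
    // SortOpsAndDelayOptimizeOp then move optimizer ops after the last backward
    // op and strip the control-dependency variables from inputs/outputs.
    return fw::ir::TopologySort(graph);
  }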