diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 3914e08d99547b22d82ec187fe405dc7b551373c..a174aa88d937bf2b9786863b5e21dedd2fc1af8f 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/operator.h" -DEFINE_bool(convert_all_blocks, false, +DEFINE_bool(convert_all_blocks, true, "Convert all blocks in program into SSAgraphs"); namespace paddle { diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 7b6002da0966f57fbcbd36018ece3159d4403ef8..0a856330f8e742b5fc2bb797f1402174dc786889 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -24,15 +24,17 @@ namespace paddle { namespace framework { namespace ir { namespace { -void SortHelper(const std::map, - ir::NodeComp> &adj_list, + +template +void SortHelper(const std::map, + NodeComparator> &adj_list, ir::Node *node, std::unordered_set *visited, std::vector *ret) { visited->insert(node); for (auto adj : adj_list.at(node)) { if (visited->find(adj) == visited->end()) { - SortHelper(adj_list, adj, visited, ret); + SortHelper(adj_list, adj, visited, ret); } } @@ -41,10 +43,11 @@ void SortHelper(const std::map, ret->push_back(node); } +template bool HasCircleHelper( ir::Node *node, - const std::map, ir::NodeComp> - &adj_list, + const std::map, + NodeComparator> &adj_list, std::unordered_set *visited, std::unordered_set *in_trace, std::vector> *circles) { @@ -54,7 +57,8 @@ bool HasCircleHelper( for (ir::Node *in : adj_list.at(node)) { if (visited->find(in) == visited->end() && - HasCircleHelper(in, adj_list, visited, in_trace, circles)) { + HasCircleHelper(in, adj_list, visited, in_trace, + circles)) { return true; } else if (in_trace->find(in) != in_trace->end()) { if (circles != nullptr) { @@ -77,14 +81,16 @@ bool HasCircleHelper( return false; } +template bool HasCircleInternal( - const std::map, ir::NodeComp> - &adj_list, + const std::map, + NodeComparator> &adj_list, std::vector> *circles) { std::unordered_set visited; std::unordered_set in_trace; for (auto &adj : adj_list) { - if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace, circles)) { + if (HasCircleHelper(adj.first, adj_list, &visited, + &in_trace, circles)) { return true; } } @@ -136,7 +142,7 @@ std::vector TopologySortOperations(const Graph &graph) { std::vector ret; for (auto adj : adj_list) { if (visited.find(adj.first) == visited.end()) { - SortHelper(adj_list, adj.first, &visited, &ret); + SortHelper(adj_list, adj.first, &visited, &ret); } } @@ -169,34 +175,6 @@ bool IsTopologySortOperationsUnique(const Graph &graph) { return true; } -// Build operator inlink edge table. -std::map, ir::NodeComp> -BuildOperationAdjList(const Graph &graph) { - std::map, ir::NodeComp> - adj_list; - - for (auto &n : graph.Nodes()) { - if (!n->IsOp()) continue; - if (adj_list.find(n) == adj_list.end()) { - adj_list[n] = std::set(); - } - for (auto &var : n->inputs) { - for (auto &adj_n : var->inputs) { - PADDLE_ENFORCE_EQ( - adj_n->NodeType(), ir::Node::Type::kOperation, - platform::errors::InvalidArgument( - "Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(), - static_cast(adj_n->NodeType()))); - VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast(adj_n) - << " -> " << n->Name() << reinterpret_cast(n) - << " via " << var->Name() << reinterpret_cast(var); - adj_list[n].insert(adj_n); - } - } - } - return adj_list; -} - // Build operator outlink edge table. std::map> BuildOperationOutAdjList( const Graph &graph) { @@ -424,81 +402,33 @@ std::vector TopologyVarientSort(const Graph &graph, class DescOrderComparator { public: - bool operator()(const Node *n1, const Node *n2) { - return (n1->DescOrder() > n2->DescOrder()) || - ((n1->DescOrder() == n2->DescOrder()) && - (n1->ToString() > n2->ToString())); + bool operator()(Node *const &n1, Node *const &n2) const { + if (n1->DescOrder() < n2->DescOrder()) { + return true; + } else if (n1->DescOrder() == n2->DescOrder()) { + return n1->id() < n2->id() || + (n1->id() == n2->id() && n1->ToString() < n2->ToString()); + } + return false; } }; std::vector TopologySortGraphByDescOrder(const Graph &graph) { - std::vector sorted_ops; - std::priority_queue, DescOrderComparator> q; - std::unordered_map> in_ops; - std::unordered_map> out_ops; - - // ensure all op node in 'in_ops' and 'out_ops' - for (const auto &n : graph.Nodes()) { - if (!n->IsOp()) continue; - - in_ops.emplace(n, std::unordered_set()); - out_ops.emplace(n, std::unordered_set()); - } - - // record all op's input op and output op - for (const auto &n : graph.Nodes()) { - if (!n->IsOp()) continue; - - // traverse all input op - for (const auto &var : n->inputs) { - for (const auto &in : var->inputs) { - // use at instead of [] to prevent no unrecorded op node - in_ops.at(n).insert(in); - out_ops.at(in).insert(n); - } - } - } - - // find topology entrance - for (const auto &n : graph.Nodes()) { - if (!n->IsOp()) continue; - - if (in_ops.at(n).empty()) { - q.push(n); - } - } - - // topological sorting - while (!q.empty()) { - // Do not get by reference!!! The element will pop later. - const auto cur_op = q.top(); - q.pop(); - - sorted_ops.push_back(cur_op); - for (const auto &out : out_ops.at(cur_op)) { - PADDLE_ENFORCE_GT(in_ops.at(out).count(cur_op), 0, - platform::errors::InvalidArgument( - "We find %s in %s's output list, " - "but cannot find %s in %s's input list. " - "Please ensure graph completely.", - out->Name().c_str(), cur_op->Name().c_str(), - cur_op->Name().c_str(), out->Name().c_str())); - in_ops.at(out).erase(cur_op); - - // push if in-degree is 0 - if (in_ops.at(out).empty()) { - q.push(out); - } + std::map, + DescOrderComparator> + adj_list = BuildOperationAdjList(graph); + PADDLE_ENFORCE_EQ(HasCircleInternal(adj_list, nullptr), + false, platform::errors::InvalidArgument( + "Generated graph shouldn't contain cycle.")); + std::unordered_set visited; + std::vector ret; + for (auto adj : adj_list) { + if (visited.find(adj.first) == visited.end()) { + SortHelper(adj_list, adj.first, &visited, &ret); } } - PADDLE_ENFORCE_EQ( - sorted_ops.size(), in_ops.size(), - platform::errors::InvalidArgument("Topological sorting incompletely, " - "only sorted %zd op but total %zd.", - sorted_ops.size(), in_ops.size())); - - return sorted_ops; + return ret; } } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index 3c3ea662502b57ad757ec19128dce49340e3c0bf..3309f600730e8c3fa4e5a3ab5a186e1550a61cf0 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -77,9 +77,34 @@ std::vector TopologyVarientSort(const Graph &graph, SortKind sort_kind); // Clean the nodes that doesn't connect to others. void CleanIndividualNodes(Graph *graph); -// Build an adjacency list of operations for the `graph`. -std::map, ir::NodeComp> -BuildOperationAdjList(const Graph &graph); +// Build an in-link adjacency list of operations for the `graph`. +template +std::map, NodeComparator> +BuildOperationAdjList(const Graph &graph) { + std::map, NodeComparator> + adj_list; + + for (auto &n : graph.Nodes()) { + if (!n->IsOp()) continue; + if (adj_list.find(n) == adj_list.end()) { + adj_list[n] = std::set(); + } + for (auto &var : n->inputs) { + for (auto &adj_n : var->inputs) { + PADDLE_ENFORCE_EQ( + adj_n->NodeType(), ir::Node::Type::kOperation, + platform::errors::InvalidArgument( + "Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(), + static_cast(adj_n->NodeType()))); + VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast(adj_n) + << " -> " << n->Name() << reinterpret_cast(n) + << " via " << var->Name() << reinterpret_cast(var); + adj_list[n].insert(adj_n); + } + } + } + return adj_list; +} template std::vector FilterByNodeWrapper(const Graph &graph) { diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index abc10765e4a37000412534e5396b7e9ef792a00d..fc8d7ac949a0217931a33d58e8c506a86ab6eba4 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -31,6 +31,7 @@ namespace py = pybind11; using paddle::framework::ir::Graph; using paddle::framework::ir::Node; +using paddle::framework::ir::NodeComp; using paddle::framework::ir::GraphSafeRemoveNodes; using paddle::framework::ir::HasCircle; using paddle::framework::ir::GraphNum; @@ -50,7 +51,7 @@ void BindGraph(py::module *m) { m->def("graph_num", GraphNum); m->def("topology_sort", TopologySortOperations, return_value_policy::reference); - m->def("build_adjacency_list", BuildOperationAdjList, + m->def("build_adjacency_list", BuildOperationAdjList, return_value_policy::reference); py::class_>( *m, "Graph", diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 2986ffc1116508e25e43c1fe89e4b6efce14dd89..fcb2641710facb95d9ae71152a8d02844afb629a 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -179,6 +179,7 @@ def __bootstrap__(): sysstr = platform.system() read_env_flags = [ 'check_nan_inf', + 'convert_all_blocks', 'benchmark', 'eager_delete_scope', 'fraction_of_cpu_memory_to_use',