diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 28a37f331c100695f0ffec7288db84f4493d68a0..12ce99c8788625e2aae6e07abdea565bb2c2ebb9 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -31,10 +31,10 @@ namespace paddle { namespace framework { namespace ir { namespace { -void SortHelper( - const std::map> &adj_list, - ir::Node *node, std::unordered_set *visited, - std::vector *ret) { +void SortHelper(const std::map, + ir::NodeComp> &adj_list, + ir::Node *node, std::unordered_set *visited, + std::vector *ret) { visited->insert(node); for (auto adj : adj_list.at(node)) { @@ -50,7 +50,8 @@ void SortHelper( bool HasCircleHelper( ir::Node *node, - const std::map> &adj_list, + const std::map, ir::NodeComp> + &adj_list, std::unordered_set *visited, std::unordered_set *in_trace, std::vector> *circles) { @@ -84,7 +85,8 @@ bool HasCircleHelper( } bool HasCircleInternal( - const std::map> &adj_list, + const std::map, ir::NodeComp> + &adj_list, std::vector> *circles) { std::unordered_set visited; std::unordered_set in_trace; @@ -107,8 +109,8 @@ bool FindCircleSubGraph(const Graph &graph, } std::vector TopologySortOperations(const Graph &graph) { - std::map> adj_list = - BuildOperationAdjList(graph); + std::map, ir::NodeComp> + adj_list = BuildOperationAdjList(graph); PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr)); std::unordered_set visited; std::vector ret; @@ -117,34 +119,30 @@ std::vector TopologySortOperations(const Graph &graph) { SortHelper(adj_list, adj.first, &visited, &ret); } } + return ret; } // Build operator inlink edge table. -std::map> BuildOperationAdjList( - const Graph &graph) { - std::map> adj_list; +std::map, ir::NodeComp> +BuildOperationAdjList(const Graph &graph) { + std::map, ir::NodeComp> + adj_list; for (auto &n : graph.Nodes()) { if (!n->IsOp()) continue; if (adj_list.find(n) == adj_list.end()) { - adj_list[n] = std::unordered_set(); + adj_list[n] = std::set(); } - std::vector nodes; for (auto &var : n->inputs) { for (auto &adj_n : var->inputs) { PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast(adj_n) << " -> " << n->Name() << reinterpret_cast(n) << " via " << var->Name() << reinterpret_cast(var); - nodes.push_back(adj_n); + adj_list[n].insert(adj_n); } } - std::sort(nodes.begin(), nodes.end(), [](ir::Node *node1, ir::Node *node2) { - return node1->id() > node2->id(); - }); - adj_list[n].insert(std::make_move_iterator(nodes.begin()), - std::make_move_iterator(nodes.end())); } return adj_list; } diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index 214de9ec7d85aee6021b18866295777e317aa79d..849a9c3be6904f3f9c3669d8fc9d750154863031 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include #include "paddle/fluid/framework/ir/graph.h" @@ -25,6 +26,13 @@ namespace paddle { namespace framework { namespace ir { +// Compare nodes via node id. +struct NodeComp { + bool operator()(ir::Node *const &node1, ir::Node *const &node2) const { + return node1->id() < node2->id(); + } +}; + // Test if the graph contains circle. bool HasCircle(const Graph &graph); @@ -57,8 +65,8 @@ std::vector TopologyVarientSort(const Graph &graph, SortKind sort_kind); void CleanIndividualNodes(Graph *graph); // Build an adjacency list of operations for the `graph`. -std::map> BuildOperationAdjList( - const Graph &graph); +std::map, ir::NodeComp> +BuildOperationAdjList(const Graph &graph); template std::vector FilterByNodeWrapper(const Graph &graph) { diff --git a/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc b/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc index bd4f1b61973fb0de06dcc288e329c94756d5ed47..a23297f29cf65d891f530850ffd184aa58e10886 100644 --- a/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc @@ -214,28 +214,23 @@ TEST(Analyzer_Transformer, fuse_statis) { } // Compare result of NativeConfig and AnalysisConfig -// void compare(bool use_mkldnn = false) { -// AnalysisConfig cfg; -// SetConfig(&cfg); -// if (use_mkldnn) { -// cfg.EnableMKLDNN(); -// } -// -// std::vector> input_slots_all; -// SetInput(&input_slots_all); -// CompareNativeAndAnalysis( -// reinterpret_cast(&cfg), -// input_slots_all); -// } - -// TODO(yihuaxu): -// Disable compare and compare_mkldnn temporary, see -// https://github.com/paddlePaddle/Paddle/issues/16316 for details. -// TEST(Analyzer_Transformer, compare) { compare(); } -// #ifdef PADDLE_WITH_MKLDNN -// TEST(Analyzer_Transformer, compare_mkldnn) { compare(true /* use_mkldnn */); -// } -// #endif +void compare(bool use_mkldnn = false) { + AnalysisConfig cfg; + SetConfig(&cfg); + if (use_mkldnn) { + cfg.EnableMKLDNN(); + } + + std::vector> input_slots_all; + SetInput(&input_slots_all); + CompareNativeAndAnalysis( + reinterpret_cast(&cfg), input_slots_all); +} + +TEST(Analyzer_Transformer, compare) { compare(); } +#ifdef PADDLE_WITH_MKLDNN +TEST(Analyzer_Transformer, compare_mkldnn) { compare(true /* use_mkldnn */); } +#endif } // namespace inference } // namespace paddle diff --git a/paddle/fluid/op_use_default_grad_op_maker.spec b/paddle/fluid/op_use_default_grad_op_maker.spec index 21a25ce7d5e2bad172cf50cee6138ef4b44b07c1..63eaa676a43fc784dce2437ca15bc85e2295dbb7 100644 --- a/paddle/fluid/op_use_default_grad_op_maker.spec +++ b/paddle/fluid/op_use_default_grad_op_maker.spec @@ -29,6 +29,8 @@ pool3d prelu quantize rank_loss +reduce_all +reduce_any reduce_max reduce_mean reduce_min