From 796c87d56366ea7fccd5f511b057121ca20ee65e Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Tue, 4 Sep 2018 12:24:38 +0800 Subject: [PATCH] bugfix/fusion lstm (#13185) --- paddle/fluid/framework/ir/fc_fuse_pass.cc | 2 +- .../framework/ir/graph_pattern_detector.cc | 7 +++++- .../framework/ir/graph_pattern_detector.h | 2 ++ .../ir/graph_pattern_detector_tester.cc | 5 ++-- .../framework/ir/infer_clean_graph_pass.cc | 23 +++++++++---------- .../inference/analysis/analyzer_tester.cc | 17 +++++++++++--- 6 files changed, 37 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index 513742bab69..4bdc21a47fb 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -77,7 +77,7 @@ bool LinksReplace(std::vector* links, Node* from, Node* to) { std::unique_ptr FCFusePass::ApplyImpl( std::unique_ptr graph) const { PADDLE_ENFORCE(graph.get()); - FusePassBase::Init("fc", graph.get()); + FusePassBase::Init("fc_fuse", graph.get()); std::unordered_set nodes2delete; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index f651ab635ea..16b51423d25 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -111,6 +111,11 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { return false; } } + for (auto& item : pdnodes2nodes_) { + for (auto& n : item.second) { + GetMarkedNodes(const_cast(&graph)).insert(n); + } + } VLOG(3) << pdnodes2nodes_.size() << " nodes marked"; return !pdnodes2nodes_.empty(); @@ -278,7 +283,7 @@ void GraphPatternDetector::RemoveOverlappedMatch( for (const auto& subgraph : *subgraphs) { bool valid = true; for (auto& item : subgraph) { - if (node_set.count(item.second)) { + if (item.first->IsIntermediate() && node_set.count(item.second)) { valid = false; break; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 024ce8ce556..e27246801a6 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -245,6 +245,8 @@ class GraphPatternDetector { void UniquePatterns(std::vector* subgraphs); // Remove overlapped match subgraphs, when overlapped, keep the previous one. + // The intermediate PDNodes will be removed, so can't shared by multiple + // patterns. void RemoveOverlappedMatch(std::vector* subgraphs); // Validate whether the intermediate nodes are linked by external nodes. diff --git a/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc index 7e5c86b033a..6c466fb21fb 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc @@ -140,8 +140,9 @@ TEST(GraphPatternDetecter, MultiSubgraph) { return node->IsOp() && (node->Name() == "op2" || node->Name() == "op3"); }, "OP0"); - auto* any_var = x.mutable_pattern()->NewNode( - [](Node* node) { return node->IsVar(); }, "VAR"); + auto* any_var = x.mutable_pattern() + ->NewNode([](Node* node) { return node->IsVar(); }, "VAR") + ->AsIntermediate(); auto* any_op1 = x.mutable_pattern()->NewNode( [](Node* node) { return node->IsOp(); }, "OP1"); diff --git a/paddle/fluid/framework/ir/infer_clean_graph_pass.cc b/paddle/fluid/framework/ir/infer_clean_graph_pass.cc index f885567da19..7713ed1eab8 100644 --- a/paddle/fluid/framework/ir/infer_clean_graph_pass.cc +++ b/paddle/fluid/framework/ir/infer_clean_graph_pass.cc @@ -13,42 +13,41 @@ // limitations under the License. #include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { namespace ir { -class InferCleanGraphPass : public Pass { +class InferCleanGraphPass : public FusePassBase { public: virtual ~InferCleanGraphPass() {} protected: std::unique_ptr ApplyImpl(std::unique_ptr graph) const { + FusePassBase::Init("original_graph", graph.get()); PADDLE_ENFORCE(graph.get()); auto is_valid_node = [](Node* x) { return x && IsControlDepVar(*x) && x->IsVar() && !x->Var(); }; - std::unordered_set invalid_nodes; + std::unordered_set invalid_nodes; + int valid_op = 0; for (auto* node : graph->Nodes()) { if (is_valid_node(node)) { invalid_nodes.insert(node); + } else if (node->IsOp()) { + // Collect all the operators to help tracking number of operators. + ++valid_op; } } - // remove nodes from the graph. - for (auto* node : invalid_nodes) { - graph->RemoveNode(node); - } + GraphSafeRemoveNodes(graph.get(), invalid_nodes); - // clean edges. - for (auto* node : graph->Nodes()) { - CleanEdges(&node->inputs, invalid_nodes); - CleanEdges(&node->outputs, invalid_nodes); - } + AddStatis(valid_op); return graph; } diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index ec1f3979a74..0e4d65cc859 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -327,9 +327,20 @@ void TestDituRNNPrediction(const std::string &model_path, LOG(INFO) << "fused " << item.first << " " << item.second; } - ASSERT_TRUE(fuse_statis.count("fc")); - EXPECT_EQ(fuse_statis.at("fc"), 1); - EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 1); + int num_ops = 0; + for (auto &node : + analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) { + if (node->IsFunction()) { + ++num_ops; + } + } + LOG(INFO) << "has num ops: " << num_ops; + + ASSERT_TRUE(fuse_statis.count("fc_fuse")); + EXPECT_EQ(fuse_statis.at("fc_fuse"), 1); + EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM + EXPECT_EQ(num_ops, + 13); // After graph optimization, only 13 operators exists. } } -- GitLab