未验证 提交 796c87d5 编写于 作者: Y Yan Chunwei 提交者: GitHub

bugfix/fusion lstm (#13185)

上级 9557cc21
...@@ -77,7 +77,7 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) { ...@@ -77,7 +77,7 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get()); PADDLE_ENFORCE(graph.get());
FusePassBase::Init("fc", graph.get()); FusePassBase::Init("fc_fuse", graph.get());
std::unordered_set<Node*> nodes2delete; std::unordered_set<Node*> nodes2delete;
......
...@@ -111,6 +111,11 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { ...@@ -111,6 +111,11 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
return false; return false;
} }
} }
for (auto& item : pdnodes2nodes_) {
for (auto& n : item.second) {
GetMarkedNodes(const_cast<Graph*>(&graph)).insert(n);
}
}
VLOG(3) << pdnodes2nodes_.size() << " nodes marked"; VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
return !pdnodes2nodes_.empty(); return !pdnodes2nodes_.empty();
...@@ -278,7 +283,7 @@ void GraphPatternDetector::RemoveOverlappedMatch( ...@@ -278,7 +283,7 @@ void GraphPatternDetector::RemoveOverlappedMatch(
for (const auto& subgraph : *subgraphs) { for (const auto& subgraph : *subgraphs) {
bool valid = true; bool valid = true;
for (auto& item : subgraph) { for (auto& item : subgraph) {
if (node_set.count(item.second)) { if (item.first->IsIntermediate() && node_set.count(item.second)) {
valid = false; valid = false;
break; break;
} }
......
...@@ -245,6 +245,8 @@ class GraphPatternDetector { ...@@ -245,6 +245,8 @@ class GraphPatternDetector {
void UniquePatterns(std::vector<subgraph_t>* subgraphs); void UniquePatterns(std::vector<subgraph_t>* subgraphs);
// Remove overlapped match subgraphs, when overlapped, keep the previous one. // Remove overlapped match subgraphs, when overlapped, keep the previous one.
// The intermediate PDNodes will be removed, so can't shared by multiple
// patterns.
void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs); void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
// Validate whether the intermediate nodes are linked by external nodes. // Validate whether the intermediate nodes are linked by external nodes.
......
...@@ -140,8 +140,9 @@ TEST(GraphPatternDetecter, MultiSubgraph) { ...@@ -140,8 +140,9 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
return node->IsOp() && (node->Name() == "op2" || node->Name() == "op3"); return node->IsOp() && (node->Name() == "op2" || node->Name() == "op3");
}, },
"OP0"); "OP0");
auto* any_var = x.mutable_pattern()->NewNode( auto* any_var = x.mutable_pattern()
[](Node* node) { return node->IsVar(); }, "VAR"); ->NewNode([](Node* node) { return node->IsVar(); }, "VAR")
->AsIntermediate();
auto* any_op1 = x.mutable_pattern()->NewNode( auto* any_op1 = x.mutable_pattern()->NewNode(
[](Node* node) { return node->IsOp(); }, "OP1"); [](Node* node) { return node->IsOp(); }, "OP1");
......
...@@ -13,42 +13,41 @@ ...@@ -13,42 +13,41 @@
// limitations under the License. // limitations under the License.
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class InferCleanGraphPass : public Pass { class InferCleanGraphPass : public FusePassBase {
public: public:
virtual ~InferCleanGraphPass() {} virtual ~InferCleanGraphPass() {}
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init("original_graph", graph.get());
PADDLE_ENFORCE(graph.get()); PADDLE_ENFORCE(graph.get());
auto is_valid_node = [](Node* x) { auto is_valid_node = [](Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var(); return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
}; };
std::unordered_set<Node*> invalid_nodes; std::unordered_set<const Node*> invalid_nodes;
int valid_op = 0;
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (is_valid_node(node)) { if (is_valid_node(node)) {
invalid_nodes.insert(node); invalid_nodes.insert(node);
} else if (node->IsOp()) {
// Collect all the operators to help tracking number of operators.
++valid_op;
} }
} }
// remove nodes from the graph. GraphSafeRemoveNodes(graph.get(), invalid_nodes);
for (auto* node : invalid_nodes) {
graph->RemoveNode(node);
}
// clean edges. AddStatis(valid_op);
for (auto* node : graph->Nodes()) {
CleanEdges(&node->inputs, invalid_nodes);
CleanEdges(&node->outputs, invalid_nodes);
}
return graph; return graph;
} }
......
...@@ -327,9 +327,20 @@ void TestDituRNNPrediction(const std::string &model_path, ...@@ -327,9 +327,20 @@ void TestDituRNNPrediction(const std::string &model_path,
LOG(INFO) << "fused " << item.first << " " << item.second; LOG(INFO) << "fused " << item.first << " " << item.second;
} }
ASSERT_TRUE(fuse_statis.count("fc")); int num_ops = 0;
EXPECT_EQ(fuse_statis.at("fc"), 1); for (auto &node :
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 1); analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
if (node->IsFunction()) {
++num_ops;
}
}
LOG(INFO) << "has num ops: " << num_ops;
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM
EXPECT_EQ(num_ops,
13); // After graph optimization, only 13 operators exists.
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册