未验证 提交 796c87d5 编写于 作者: Y Yan Chunwei 提交者: GitHub

bugfix/fusion lstm (#13185)

上级 9557cc21
......@@ -77,7 +77,7 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("fc", graph.get());
FusePassBase::Init("fc_fuse", graph.get());
std::unordered_set<Node*> nodes2delete;
......
......@@ -111,6 +111,11 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
return false;
}
}
for (auto& item : pdnodes2nodes_) {
for (auto& n : item.second) {
GetMarkedNodes(const_cast<Graph*>(&graph)).insert(n);
}
}
VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
return !pdnodes2nodes_.empty();
......@@ -278,7 +283,7 @@ void GraphPatternDetector::RemoveOverlappedMatch(
for (const auto& subgraph : *subgraphs) {
bool valid = true;
for (auto& item : subgraph) {
if (node_set.count(item.second)) {
if (item.first->IsIntermediate() && node_set.count(item.second)) {
valid = false;
break;
}
......
......@@ -245,6 +245,8 @@ class GraphPatternDetector {
void UniquePatterns(std::vector<subgraph_t>* subgraphs);
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
// The intermediate PDNodes will be removed, so can't shared by multiple
// patterns.
void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
// Validate whether the intermediate nodes are linked by external nodes.
......
......@@ -140,8 +140,9 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
return node->IsOp() && (node->Name() == "op2" || node->Name() == "op3");
},
"OP0");
auto* any_var = x.mutable_pattern()->NewNode(
[](Node* node) { return node->IsVar(); }, "VAR");
auto* any_var = x.mutable_pattern()
->NewNode([](Node* node) { return node->IsVar(); }, "VAR")
->AsIntermediate();
auto* any_op1 = x.mutable_pattern()->NewNode(
[](Node* node) { return node->IsOp(); }, "OP1");
......
......@@ -13,42 +13,41 @@
// limitations under the License.
#include <algorithm>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class InferCleanGraphPass : public Pass {
class InferCleanGraphPass : public FusePassBase {
public:
virtual ~InferCleanGraphPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init("original_graph", graph.get());
PADDLE_ENFORCE(graph.get());
auto is_valid_node = [](Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
};
std::unordered_set<Node*> invalid_nodes;
std::unordered_set<const Node*> invalid_nodes;
int valid_op = 0;
for (auto* node : graph->Nodes()) {
if (is_valid_node(node)) {
invalid_nodes.insert(node);
} else if (node->IsOp()) {
// Collect all the operators to help tracking number of operators.
++valid_op;
}
}
// remove nodes from the graph.
for (auto* node : invalid_nodes) {
graph->RemoveNode(node);
}
GraphSafeRemoveNodes(graph.get(), invalid_nodes);
// clean edges.
for (auto* node : graph->Nodes()) {
CleanEdges(&node->inputs, invalid_nodes);
CleanEdges(&node->outputs, invalid_nodes);
}
AddStatis(valid_op);
return graph;
}
......
......@@ -327,9 +327,20 @@ void TestDituRNNPrediction(const std::string &model_path,
LOG(INFO) << "fused " << item.first << " " << item.second;
}
ASSERT_TRUE(fuse_statis.count("fc"));
EXPECT_EQ(fuse_statis.at("fc"), 1);
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 1);
int num_ops = 0;
for (auto &node :
analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
if (node->IsFunction()) {
++num_ops;
}
}
LOG(INFO) << "has num ops: " << num_ops;
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM
EXPECT_EQ(num_ops,
13); // After graph optimization, only 13 operators exists.
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册