未验证 提交 72af57bb 编写于 作者: W Wangzheee 提交者: GitHub

[pass_enhance] : seq_concat_fc_fuse_pass (#33961)

上级 740f4e30
...@@ -174,6 +174,91 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) { ...@@ -174,6 +174,91 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
return fc_out; return fc_out;
} }
SeqConcatFcFusePass::SeqConcatFcFusePass() {
AddOpCompat(OpCompat("sequence_expand"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("ref_level")
.IsNumEQ(0)
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("tanh"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const { void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("seq_concat_fc_fuse", graph); FusePassBase::Init("seq_concat_fc_fuse", graph);
GraphPatternDetector detector; GraphPatternDetector detector;
...@@ -193,6 +278,10 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -193,6 +278,10 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
detector(graph, [&](const GraphPatternDetector::subgraph_t& subgraph, detector(graph, [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "seq_concat_fc_fuse_pass in op compat failed.";
return;
}
VLOG(4) << "get one concat pattern"; VLOG(4) << "get one concat pattern";
// fc // fc
GET_NODE(fc_w, detector.pattern()); GET_NODE(fc_w, detector.pattern());
......
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
#pragma once #pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -26,6 +24,7 @@ class Graph; ...@@ -26,6 +24,7 @@ class Graph;
class SeqConcatFcFusePass : public FusePassBase { class SeqConcatFcFusePass : public FusePassBase {
public: public:
SeqConcatFcFusePass();
virtual ~SeqConcatFcFusePass() {} virtual ~SeqConcatFcFusePass() {}
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册