未验证 提交 c3008e75 编写于 作者: X xiaoting 提交者: GitHub

Add seqconv pass enhance (#33455)

上级 cca44c1d
...@@ -27,16 +27,65 @@ namespace paddle { ...@@ -27,16 +27,65 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
SeqConvEltAddReluFusePass::SeqConvEltAddReluFusePass() {
AddOpCompat(OpCompat("sequence_conv"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("PaddingData")
.IsOptional()
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("contextLength")
.IsNumGT(0)
.End()
.AddAttr("contextStart") // the contextStart attribute can be negative,
// unconstrained
.End()
.AddAttr("contextStride")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
class Node; class Node;
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) { void SeqConvEltAddReluFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X")) PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope_, "X"))
->assert_is_op_input("sequence_conv") ->assert_is_op_input("sequence_conv")
->assert_var_not_persistable(); ->assert_var_not_persistable();
patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope); patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope_);
fuse_pattern(x); fuse_pattern(x);
// Create New OpDesc // Create New OpDesc
...@@ -70,6 +119,10 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) { ...@@ -70,6 +119,10 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle SeqConv EltAdd Relu fuse"; VLOG(4) << "handle SeqConv EltAdd Relu fuse";
GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
...@@ -89,14 +142,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) { ...@@ -89,14 +142,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
}; };
gpd(graph, handler); gpd(graph, handler);
return fusion_count;
}
void SeqConvEltAddReluFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph, name_scope_, param_scope());
AddStatis(fusion_count); AddStatis(fusion_count);
} }
......
...@@ -28,6 +28,7 @@ class Graph; ...@@ -28,6 +28,7 @@ class Graph;
class SeqConvEltAddReluFusePass : public FusePassBase { class SeqConvEltAddReluFusePass : public FusePassBase {
public: public:
SeqConvEltAddReluFusePass();
virtual ~SeqConvEltAddReluFusePass() {} virtual ~SeqConvEltAddReluFusePass() {}
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册