From c3008e751e3f20ccfc3eb1945d590e402c2f674e Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Fri, 18 Jun 2021 14:10:28 +0800 Subject: [PATCH] Add seqconv pass enhance (#33455) --- .../ir/seqconv_eltadd_relu_fuse_pass.cc | 67 ++++++++++++++++--- .../ir/seqconv_eltadd_relu_fuse_pass.h | 1 + 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc index 9337a67651e..9fa951920f4 100644 --- a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc @@ -27,16 +27,65 @@ namespace paddle { namespace framework { namespace ir { +SeqConvEltAddReluFusePass::SeqConvEltAddReluFusePass() { + AddOpCompat(OpCompat("sequence_conv")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("PaddingData") + .IsOptional() + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("contextLength") + .IsNumGT(0) + .End() + .AddAttr("contextStart") // the contextStart attribute can be negative, + // unconstrained + .End() + .AddAttr("contextStride") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + class Node; -int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) { +void SeqConvEltAddReluFusePass::ApplyImpl(ir::Graph* graph) const { + FusePassBase::Init(name_scope_, graph); GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); - PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X")) + PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope_, "X")) ->assert_is_op_input("sequence_conv") ->assert_var_not_persistable(); - patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope); + patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope_); fuse_pattern(x); // Create New OpDesc @@ -70,6 +119,10 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } VLOG(4) << "handle SeqConv EltAdd Relu fuse"; GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern); @@ -89,14 +142,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) { }; gpd(graph, handler); - - return fusion_count; -} - -void SeqConvEltAddReluFusePass::ApplyImpl(ir::Graph* graph) const { - FusePassBase::Init(name_scope_, graph); - - int fusion_count = BuildFusion(graph, name_scope_, param_scope()); AddStatis(fusion_count); } diff --git a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h index 6f623625f51..fe06002251a 100644 --- a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h +++ b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h @@ -28,6 +28,7 @@ class Graph; class SeqConvEltAddReluFusePass : public FusePassBase { public: + SeqConvEltAddReluFusePass(); virtual ~SeqConvEltAddReluFusePass() {} protected: -- GitLab