diff --git a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
index 9337a67651ee3c16604bfb12314a6d6bb8dce71c..9fa951920f45a311314832cdaa0e61b5319a8551 100644
--- a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
@@ -27,16 +27,65 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+SeqConvEltAddReluFusePass::SeqConvEltAddReluFusePass() {
+  AddOpCompat(OpCompat("sequence_conv"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("PaddingData")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("contextLength")
+      .IsNumGT(0)
+      .End()
+      .AddAttr("contextStart")  // the contextStart attribute can be negative,
+                                // unconstrained
+      .End()
+      .AddAttr("contextStride")
+      .IsNumEQ(1)
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+
+  AddOpCompat(OpCompat("relu"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+}
+
 class Node;
 
-int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
+void SeqConvEltAddReluFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
-  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X"))
+  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope_, "X"))
                   ->assert_is_op_input("sequence_conv")
                   ->assert_var_not_persistable();
-  patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope);
+  patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope_);
   fuse_pattern(x);
 
   // Create New OpDesc
@@ -70,6 +119,10 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     VLOG(4) << "handle SeqConv EltAdd Relu fuse";
     GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
@@ -89,14 +142,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
   };
 
   gpd(graph, handler);
-
-  return fusion_count;
-}
-
-void SeqConvEltAddReluFusePass::ApplyImpl(ir::Graph* graph) const {
-  FusePassBase::Init(name_scope_, graph);
-
-  int fusion_count = BuildFusion(graph, name_scope_, param_scope());
   AddStatis(fusion_count);
 }
 
diff --git a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h
index 6f623625f51d8217370f2eabfb6820eebeb6e07a..fe06002251ae2adefc64c431446f90aad5ea85b4 100644
--- a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h
+++ b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h
@@ -28,6 +28,7 @@ class Graph;
 
 class SeqConvEltAddReluFusePass : public FusePassBase {
  public:
+  SeqConvEltAddReluFusePass();
   virtual ~SeqConvEltAddReluFusePass() {}
 
  protected:
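
For context: the patch replaces the free BuildFusion helper with SeqConvEltAddReluFusePass::ApplyImpl, declares per-op constraints in the new constructor, and skips any matched subgraph that fails IsCompat. Below is a minimal, self-contained C++ sketch of that fluent constraint-registration idea; all names in it (OpCompatSketch, AddAttr, Check) are invented for illustration and are not the real Paddle OpCompat API, which also validates inputs, outputs, and tensor properties.

// Standalone sketch of the "op compat" idea used in the patch: constraints
// are registered per op via a fluent builder, then checked before a fusion
// is allowed to fire. Invented names throughout; not Paddle code.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

class OpCompatSketch {
 public:
  explicit OpCompatSketch(std::string op) : op_(std::move(op)) {}

  // Register a predicate over an integer attribute, e.g. contextStride == 1.
  // Returning *this enables chaining, like .AddAttr(...).End() in the patch.
  OpCompatSketch& AddAttr(const std::string& name,
                          std::function<bool(int)> pred) {
    attr_checks_[name] = std::move(pred);
    return *this;
  }

  // True iff every registered attribute is present and satisfies its check;
  // this plays the role of the IsCompat() guard in the fusion handler.
  bool Check(const std::map<std::string, int>& attrs) const {
    for (const auto& kv : attr_checks_) {
      auto it = attrs.find(kv.first);
      if (it == attrs.end() || !kv.second(it->second)) return false;
    }
    return true;
  }

 private:
  std::string op_;
  std::map<std::string, std::function<bool(int)>> attr_checks_;
};

int main() {
  // Mirror of the sequence_conv constraints from the patch: contextLength > 0,
  // contextStride == 1, contextStart left unconstrained.
  OpCompatSketch seq_conv("sequence_conv");
  seq_conv.AddAttr("contextLength", [](int v) { return v > 0; })
      .AddAttr("contextStride", [](int v) { return v == 1; });

  std::map<std::string, int> ok{{"contextLength", 3}, {"contextStride", 1}};
  std::map<std::string, int> bad{{"contextLength", 3}, {"contextStride", 2}};

  std::cout << seq_conv.Check(ok) << "\n";   // 1: fusion may proceed
  std::cout << seq_conv.Check(bad) << "\n";  // 0: op compat failed, skip fuse
  return 0;
}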