diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
index ac6e22862d6299d193c9baa342c8ce5a6f2c56e6..c89984f384691760a4a9032778cac99c73eede13 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
@@ -48,6 +48,60 @@ framework::proto::OpDesc PrepareOpDesc(
   return *desc.Proto();
 }
 
+ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("ResidualData")
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .End()
+      .AddAttr("paddings")
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+
+  AddOpCompat(OpCompat("relu"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+}
+
 void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "conv_elementwise_add_act_fuse";
   FusePassBase::Init(pattern_name, graph);
@@ -63,6 +117,10 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     GET_NODES;
 
     auto base_op_desc = *conv_op->Op()->Proto();
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h
index 933092c7db7d38d722af9392e71cd0c1797f0eee..d28f212f49e71be92ea9e9d0eff1683fb67c3566 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h
@@ -24,6 +24,7 @@ class Graph;
 
 class ConvElementwiseAddActFusePass : public FusePassBase {
  public:
+  ConvElementwiseAddActFusePass();
   virtual ~ConvElementwiseAddActFusePass() {}
 
  protected:
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
index 170b8fb8c80fa78884c3f4f69ebe892bc5b2908c..248a71ede14beb35db0580b879891d5b3b614157 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
@@ -29,6 +29,52 @@ namespace ir {
   GET_IR_NODE(elementwise_add_in_y);  \
   GET_IR_NODE(elementwise_add_out);
 
+ConvElementwiseAddFusePass::ConvElementwiseAddFusePass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("ResidualData")
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .End()
+      .AddAttr("paddings")
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+}
+
 void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "conv_elementwise_add_fuse";
   FusePassBase::Init(pattern_name, graph);
@@ -44,6 +90,10 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     GET_NODES;
 
     auto base_op_desc = *conv_op->Op()->Proto();
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h
index 7198a7488e052b5bdbe52d662b903d9f90c51da0..0913dc5c0022714e4013b718ab177862726dc911 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h
@@ -24,6 +24,7 @@ class Graph;
 
 class ConvElementwiseAddFusePass : public FusePassBase {
  public:
+  ConvElementwiseAddFusePass();
   virtual ~ConvElementwiseAddFusePass() {}
 
  protected: