diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc index e7656171700b4ff7dda665b985521902518d7720..f2a295694dcb962536d4956370837a7346a4bc0e 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc @@ -52,6 +52,56 @@ framework::proto::OpDesc PrepareOpDesc( desc.Flush(); return *desc.Proto(); } +ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .End() + .AddAttr("paddings") + .End() + .AddAttr("padding_algorithm") + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .End() + .AddAttr("data_format") + .IsStringIn({"NHWC", "NCHW"}) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + // the first elementwise_add-axis needs to be 1, the second has to be -1 + .IsIntIn({1, -1}) + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { const std::string pattern_name = "conv_elementwise_add2_act_fuse"; @@ -66,6 +116,10 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass op compat failed."; + return; + } GET_NODES; auto base_op_desc = *conv_op->Op()->Proto(); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h index e68f57d4ae998203c6f34aee7cca11d69a5e6d3f..3d5e5788fed2d002a63a0a6149b06be1f54e015a 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h @@ -24,6 +24,7 @@ class Graph; class ConvElementwiseAdd2ActFusePass : public FusePassBase { public: + ConvElementwiseAdd2ActFusePass(); virtual ~ConvElementwiseAdd2ActFusePass() {} protected: