From dfbfbd01e34040fe9d76c69606c3e08ca6e12f24 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Wed, 23 Jun 2021 19:20:39 +0800 Subject: [PATCH] enhance Conv elementwise add2 act fuse pass (#33564) * tmp * pass con_element_add2_act * recover unittests CMakeLists * init pass enhance * fix the attr according to review * repair the attr conv2d * repair axis of elementwise_add * CI-coverage test=allcase * repari some attr * recover batch_norm_act * conv_elementwise_add2_act_fuse --- .../ir/conv_elementwise_add2_act_fuse_pass.cc | 54 +++++++++++++++++++ .../ir/conv_elementwise_add2_act_fuse_pass.h | 1 + 2 files changed, 55 insertions(+) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc index e7656171700..f2a295694dc 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc @@ -52,6 +52,56 @@ framework::proto::OpDesc PrepareOpDesc( desc.Flush(); return *desc.Proto(); } +ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .End() + .AddAttr("paddings") + .End() + .AddAttr("padding_algorithm") + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .End() + .AddAttr("data_format") + .IsStringIn({"NHWC", "NCHW"}) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + // the first elementwise_add-axis needs to be 1, the second has to be -1 + .IsIntIn({1, -1}) + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { const std::string pattern_name = "conv_elementwise_add2_act_fuse"; @@ -66,6 +116,10 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass op compat failed."; + return; + } GET_NODES; auto base_op_desc = *conv_op->Op()->Proto(); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h index e68f57d4ae9..3d5e5788fed 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h @@ -24,6 +24,7 @@ class Graph; class ConvElementwiseAdd2ActFusePass : public FusePassBase { public: + ConvElementwiseAdd2ActFusePass(); virtual ~ConvElementwiseAdd2ActFusePass() {} protected: -- GitLab