From 899792c1519977c9a723c4ae8b59e15b532994d6 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Wed, 23 Jun 2021 15:56:13 +0800 Subject: [PATCH] pass enhance (#33710) --- .../ir/conv_elementwise_add_act_fuse_pass.cc | 58 +++++++++++++++++++ .../ir/conv_elementwise_add_act_fuse_pass.h | 1 + .../ir/conv_elementwise_add_fuse_pass.cc | 50 ++++++++++++++++ .../ir/conv_elementwise_add_fuse_pass.h | 1 + 4 files changed, 110 insertions(+) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc index ac6e22862d..c89984f384 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc @@ -48,6 +48,60 @@ framework::proto::OpDesc PrepareOpDesc( return *desc.Proto(); } +ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("ResidualData") + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .End() + .AddAttr("paddings") + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { const std::string pattern_name = "conv_elementwise_add_act_fuse"; FusePassBase::Init(pattern_name, graph); @@ -63,6 +117,10 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } GET_NODES; auto base_op_desc = *conv_op->Op()->Proto(); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h index 933092c7db..d28f212f49 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h @@ -24,6 +24,7 @@ class Graph; class ConvElementwiseAddActFusePass : public FusePassBase { public: + ConvElementwiseAddActFusePass(); virtual ~ConvElementwiseAddActFusePass() {} protected: diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc index 170b8fb8c8..248a71ede1 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc @@ -29,6 +29,52 @@ namespace ir { GET_IR_NODE(elementwise_add_in_y); \ GET_IR_NODE(elementwise_add_out); +ConvElementwiseAddFusePass::ConvElementwiseAddFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("ResidualData") + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .End() + .AddAttr("paddings") + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); +} + void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const { const std::string pattern_name = "conv_elementwise_add_fuse"; FusePassBase::Init(pattern_name, graph); @@ -44,6 +90,10 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } GET_NODES; auto base_op_desc = *conv_op->Op()->Proto(); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h index 7198a7488e..0913dc5c00 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h @@ -24,6 +24,7 @@ class Graph; class ConvElementwiseAddFusePass : public FusePassBase { public: + ConvElementwiseAddFusePass(); virtual ~ConvElementwiseAddFusePass() {} protected: -- GitLab