From 4c269ccb5965c35c3285a1a79db042d6aabc6182 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Wed, 6 Jul 2022 11:29:45 +0800 Subject: [PATCH] [Paddle Inference] Add conv_elementwise_act. (#43871) * conv_fusion --- .../ir/conv_elementwise_add2_act_fuse_pass.cc | 18 ++++++++++++++++++ .../ir/conv_elementwise_add_act_fuse_pass.cc | 18 ++++++++++++++++++ .../framework/ir/graph_pattern_detector.cc | 3 ++- paddle/fluid/operators/fused/conv_fusion_op.cu | 8 +++++--- 4 files changed, 43 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc index ff86bdb8fa8..6d9611ebd13 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc @@ -105,6 +105,22 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() { .AddOutput("Out") .IsTensor() .End(); + + AddOpCompat(OpCompat("sigmoid")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("tanh")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); } void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { @@ -188,4 +204,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass) .LE("conv2d", 1) .LE("elementwise_add", 1) .EQ("relu", 0) + .EQ("sigmoid", 0) + .EQ("tanh", 0) .EQ("identity", 0)); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc index f67e83bc101..47e2c5e380b 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc @@ -102,6 +102,22 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() { .AddOutput("Out") .IsTensor() .End(); + + AddOpCompat(OpCompat("sigmoid")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("tanh")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); } void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { @@ -170,4 +186,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass) .LE("conv2d", 1) .LE("elementwise_add", 1) .EQ("relu", 0) + .EQ("sigmoid", 0) + .EQ("tanh", 0) .EQ("identity", 0)); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 154df498e7d..f0949cb9dfb 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2324,7 +2324,8 @@ PDNode *patterns::PriorBox::operator()() { return boxes_var; } -std::unordered_set conv_act_set({"identity", "relu"}); +std::unordered_set conv_act_set( + {"identity", "relu", "sigmoid", "tanh"}); PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { conv_in->AsInput(); diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu index 5e96ca14027..2ee63c93642 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ b/paddle/fluid/operators/fused/conv_fusion_op.cu @@ -544,9 +544,11 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { namespace ops = paddle::operators; #if CUDNN_VERSION >= 7100 -REGISTER_OP_CUDA_KERNEL(conv2d_fusion, - ops::CUDNNConvFusionOpKernel, - ops::CUDNNConvFusionOpKernel); +REGISTER_OP_CUDA_KERNEL( + conv2d_fusion, + ops::CUDNNConvFusionOpKernel, + ops::CUDNNConvFusionOpKernel, + ops::CUDNNConvFusionOpKernel); #endif #ifdef PADDLE_WITH_HIP REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel); -- GitLab