diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
index ff86bdb8fa86faba363d43ca1831d5ffe800a8c7..6d9611ebd13931a58215f4638c0886881bd38c51 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
@@ -105,6 +105,22 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
       .AddOutput("Out")
       .IsTensor()
       .End();
+
+  AddOpCompat(OpCompat("sigmoid"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+
+  AddOpCompat(OpCompat("tanh"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
 }
 
 void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
@@ -188,4 +204,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass)
             .LE("conv2d", 1)
             .LE("elementwise_add", 1)
             .EQ("relu", 0)
+            .EQ("sigmoid", 0)
+            .EQ("tanh", 0)
             .EQ("identity", 0));
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
index f67e83bc10171b7f1ea5c9db43f215fbf284f568..47e2c5e380bcbfb6410e81cba9b2f7896f4018ae 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
@@ -102,6 +102,22 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
       .AddOutput("Out")
       .IsTensor()
       .End();
+
+  AddOpCompat(OpCompat("sigmoid"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+
+  AddOpCompat(OpCompat("tanh"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
 }
 
 void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
@@ -170,4 +186,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
             .LE("conv2d", 1)
             .LE("elementwise_add", 1)
             .EQ("relu", 0)
+            .EQ("sigmoid", 0)
+            .EQ("tanh", 0)
             .EQ("identity", 0));
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 154df498e7d133cfc52b7dc34073663941497a58..f0949cb9dfbd2547826c6c910717808a44809bb7 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2324,7 +2324,8 @@ PDNode *patterns::PriorBox::operator()() {
   return boxes_var;
 }
 
-std::unordered_set<std::string> conv_act_set({"identity", "relu"});
+std::unordered_set<std::string> conv_act_set(
+    {"identity", "relu", "sigmoid", "tanh"});
 
 PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
   conv_in->AsInput();
diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu
index 5e96ca140274dfda8c29484df641d9f964fe41fa..2ee63c93642218fddc331f4205d8b07575e921ec 100644
--- a/paddle/fluid/operators/fused/conv_fusion_op.cu
+++ b/paddle/fluid/operators/fused/conv_fusion_op.cu
@@ -544,9 +544,11 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 #if CUDNN_VERSION >= 7100
-REGISTER_OP_CUDA_KERNEL(conv2d_fusion,
-                        ops::CUDNNConvFusionOpKernel<float>,
-                        ops::CUDNNConvFusionOpKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    conv2d_fusion,
+    ops::CUDNNConvFusionOpKernel<float>,
+    ops::CUDNNConvFusionOpKernel<double>,
+    ops::CUDNNConvFusionOpKernel<paddle::platform::float16>);
 #endif
 #ifdef PADDLE_WITH_HIP
 REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>);
 #endif