diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index aaae505edde385b5723bdcb1987805b4ce68a5be..c817400056c2132f55b7d2ffc43af42966166eea 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -88,6 +88,13 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
     desc->SetAttr("fuse_beta",
                   activation->Op()->GetAttrIfExists<float>("beta"));
 
+    if (activation_type() == "hard_sigmoid") {
+      desc->SetAttr("fuse_alpha",
+                    activation->Op()->GetAttrIfExists<float>("slope"));
+      desc->SetAttr("fuse_beta",
+                    activation->Op()->GetAttrIfExists<float>("offset"));
+    }
+
     GraphSafeRemoveNodes(graph, {activation, conv_out});
 
     PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
@@ -213,6 +220,26 @@ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
       .End();
 }
 
+Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
+  AddOpCompat(OpCompat("hard_sigmoid"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      // optional, default=0.2
+      .AddAttr("slope")
+      .IsOptional()
+      .IsType<float>()
+      .End()
+      // optional, default=0.5
+      .AddAttr("offset")
+      .IsOptional()
+      .IsType<float>()
+      .End();
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
@@ -259,3 +286,11 @@ REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass)
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("conv2d", 1)
             .EQ("hard_swish", 0));
+
+REGISTER_PASS(conv_hard_sigmoid_mkldnn_fuse_pass,
+              paddle::framework::ir::Conv2DHardSigmoidFusePass);
+REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .LE("conv2d", 1)
+            .EQ("hard_sigmoid", 0));
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
index d22773fb41904afa17832224169f5430b94055c6..eacde101d5a0a75054609cb86dd733d910297e7a 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
@@ -72,6 +72,15 @@ class Conv2DHardSwishFusePass : public ConvActivationFusePass {
   Conv2DHardSwishFusePass();
   std::string activation_type() const { return "hard_swish"; }
 };
+/*
+ * Fuse Conv and HardSigmoid class
+ */
+class Conv2DHardSigmoidFusePass : public ConvActivationFusePass {
+ public:
+  Conv2DHardSigmoidFusePass();
+  std::string activation_type() const { return "hard_sigmoid"; }
+};
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
index 453197cda391542f41adcbeab55147b401d242f3..a398e3341698923f0896fab1eeed67ff513de592 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
@@ -148,6 +148,9 @@ TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); }
 TEST(ConvActivationFusePass, conv_hard_swish_fuse_pass) {
   MainTest("hard_swish");
 }
+TEST(ConvActivationFusePass, conv_hard_sigmoid_fuse_pass) {
+  MainTest("hard_sigmoid");
+}
 
 }  // namespace ir
 }  // namespace framework
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 5b49a0d591edd9b8bd9403a2f330a16cc0efe8ec..7d867b59e7d5ba0840f5629f41cd799e160e3b1f 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -249,6 +249,7 @@ void CpuPassStrategy::EnableMKLDNN() {
           "conv_relu6_mkldnn_fuse_pass",                //
           "conv_swish_mkldnn_fuse_pass",                //
           "conv_hard_swish_mkldnn_fuse_pass",           //
+          "conv_hard_sigmoid_mkldnn_fuse_pass",         //
           "scale_matmul_fuse_pass",                     //
           "reshape_transpose_matmul_mkldnn_fuse_pass",  //
           "matmul_transpose_reshape_fuse_pass",         //
diff --git a/paddle/fluid/operators/compat/hard_sigmoid.pbtxt b/paddle/fluid/operators/compat/hard_sigmoid.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..c8b66edf2223a3c1a846ae7dade2dee6da2bdd56
--- /dev/null
+++ b/paddle/fluid/operators/compat/hard_sigmoid.pbtxt
@@ -0,0 +1,17 @@
+type: "hard_sigmoid"
+def {
+  inputs {
+    name: "X"
+  }
+  outputs {
+    name: "Out"
+  }
+  attrs {
+    name: "slope"
+    type: FLOAT
+  }
+  attrs {
+    name: "offset"
+    type: FLOAT
+  }
+}
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index cce835e6bc0354a23710874d7acb4f3a6195c1f2..2c03da252d20ca30877184e59a12180672fa680a 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -475,23 +475,25 @@ class ConvMKLDNNHandlerT
     }
     // Fusion with ReLU layer is executed through the PostOps feature. Create a
     // PostOps object and configure it to execute an eltwise relu operation.
+    constexpr float scale = 1.0f;
     if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
-      constexpr float scale = 1.0f;
       post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
                                      fuse_alpha, fuse_beta);
     } else if (fuse_activation == "relu6") {
-      constexpr float scale = 1.0f;
       post_operations.append_eltwise(scale,
                                      mkldnn::algorithm::eltwise_bounded_relu,
                                      fuse_alpha, fuse_beta);
     } else if (fuse_activation == "swish") {
-      constexpr float scale = 1.0f;
       post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish,
                                      fuse_alpha, fuse_beta);
     } else if (fuse_activation == "hard_swish") {
-      constexpr float scale = 1.0f;
       post_operations.append_eltwise(
           scale, mkldnn::algorithm::eltwise_hardswish, fuse_alpha, fuse_beta);
+    } else if (fuse_activation == "hard_sigmoid") {
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_linear,
+                                     fuse_alpha, fuse_beta);
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_clip,
+                                     0.0f, 1.0f);
     }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py
index 11d05f32c4d13b22f67758626f724c8ad3193c4c..cf9b2257553b70bfe7de83ad05c62d2381fac0dd 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py
@@ -102,5 +102,14 @@ class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest):
         self.pass_name = 'conv_hard_swish_mkldnn_fuse_pass'
 
 
+class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
+    def set_params(self):
+        self.conv_num_filters = 5
+        self.conv_filter_size = 5
+        self.conv_bias_attr = True
+        self.act = "hard_sigmoid"
+        self.pass_name = 'conv_hard_sigmoid_mkldnn_fuse_pass'
+
+
 if __name__ == "__main__":
     unittest.main()