From 536907196cb1be3b2faea5652e09f7a165492320 Mon Sep 17 00:00:00 2001 From: jakpiase <62569058+jakpiase@users.noreply.github.com> Date: Tue, 2 Nov 2021 11:25:28 +0100 Subject: [PATCH] [Need review] Added conv + hard_sigmoid oneDNN fuse pass (#36869) * added conv + hard_sigmoid fuse pass * Removed IsOptional() statements * Reverted removing optional --- .../conv_activation_mkldnn_fuse_pass.cc | 35 +++++++++++++++++++ .../mkldnn/conv_activation_mkldnn_fuse_pass.h | 9 +++++ ...conv_activation_mkldnn_fuse_pass_tester.cc | 3 ++ .../inference/api/paddle_pass_builder.cc | 1 + .../fluid/operators/compat/hard_sigmoid.pbtxt | 17 +++++++++ .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 10 +++--- .../test_mkldnn_conv_activation_fuse_pass.py | 9 +++++ 7 files changed, 80 insertions(+), 4 deletions(-) create mode 100644 paddle/fluid/operators/compat/hard_sigmoid.pbtxt diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index aaae505edde..c817400056c 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -88,6 +88,13 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { desc->SetAttr("fuse_beta", activation->Op()->GetAttrIfExists("beta")); + if (activation_type() == "hard_sigmoid") { + desc->SetAttr("fuse_alpha", + activation->Op()->GetAttrIfExists("slope")); + desc->SetAttr("fuse_beta", + activation->Op()->GetAttrIfExists("offset")); + } + GraphSafeRemoveNodes(graph, {activation, conv_out}); PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL, @@ -213,6 +220,26 @@ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() { .End(); } +Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() { + AddOpCompat(OpCompat("hard_sigmoid")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + // optional, default=0.2 + .AddAttr("slope") + .IsOptional() + .IsType() + .End() + // optional, default=0.5 + .AddAttr("offset") + .IsOptional() + .IsType() + .End(); +} + } // namespace ir } // namespace framework } // namespace paddle @@ -259,3 +286,11 @@ REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass) paddle::framework::compatible::OpVersionComparatorCombination() .LE("conv2d", 1) .EQ("hard_swish", 0)); + +REGISTER_PASS(conv_hard_sigmoid_mkldnn_fuse_pass, + paddle::framework::ir::Conv2DHardSigmoidFusePass); +REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("conv2d", 1) + .EQ("hard_sigmoid", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h index d22773fb419..eacde101d5a 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h @@ -72,6 +72,15 @@ class Conv2DHardSwishFusePass : public ConvActivationFusePass { Conv2DHardSwishFusePass(); std::string activation_type() const { return "hard_swish"; } }; +/* + * Fuse Conv and HardSigmoid class + */ +class Conv2DHardSigmoidFusePass : public ConvActivationFusePass { + public: + Conv2DHardSigmoidFusePass(); + std::string activation_type() const { return "hard_sigmoid"; } +}; + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc index 453197cda39..a398e334169 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc @@ -148,6 +148,9 @@ TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); } TEST(ConvActivationFusePass, conv_hard_swish_fuse_pass) { MainTest("hard_swish"); } +TEST(ConvActivationFusePass, conv_hard_sigmoid_fuse_pass) { + MainTest("hard_sigmoid"); +} } // namespace ir } // namespace framework diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 5b49a0d591e..7d867b59e7d 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -249,6 +249,7 @@ void CpuPassStrategy::EnableMKLDNN() { "conv_relu6_mkldnn_fuse_pass", // "conv_swish_mkldnn_fuse_pass", // "conv_hard_swish_mkldnn_fuse_pass", // + "conv_hard_sigmoid_mkldnn_fuse_pass", // "scale_matmul_fuse_pass", // "reshape_transpose_matmul_mkldnn_fuse_pass", // "matmul_transpose_reshape_fuse_pass", // diff --git a/paddle/fluid/operators/compat/hard_sigmoid.pbtxt b/paddle/fluid/operators/compat/hard_sigmoid.pbtxt new file mode 100644 index 00000000000..c8b66edf222 --- /dev/null +++ b/paddle/fluid/operators/compat/hard_sigmoid.pbtxt @@ -0,0 +1,17 @@ +type: "hard_sigmoid" +def { + inputs { + name: "X" + } + outputs { + name: "Out" + } + attrs { + name: "slope" + type: FLOAT + } + attrs { + name: "offset" + type: FLOAT + } +} diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index cce835e6bc0..2c03da252d2 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -475,23 +475,25 @@ class ConvMKLDNNHandlerT } // Fusion with ReLU layer is executed through the PostOps feature. Create a // PostOps object and configure it to execute an eltwise relu operation. + constexpr float scale = 1.0f; if (fuse_activation == "relu" || fuse_activation == "leaky_relu") { - constexpr float scale = 1.0f; post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, fuse_alpha, fuse_beta); } else if (fuse_activation == "relu6") { - constexpr float scale = 1.0f; post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_bounded_relu, fuse_alpha, fuse_beta); } else if (fuse_activation == "swish") { - constexpr float scale = 1.0f; post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish, fuse_alpha, fuse_beta); } else if (fuse_activation == "hard_swish") { - constexpr float scale = 1.0f; post_operations.append_eltwise( scale, mkldnn::algorithm::eltwise_hardswish, fuse_alpha, fuse_beta); + } else if (fuse_activation == "hard_sigmoid") { + post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_linear, + fuse_alpha, fuse_beta); + post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_clip, + 0.0f, 1.0f); } conv_attr.set_post_ops(post_operations); return conv_attr; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py index 11d05f32c4d..cf9b2257553 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py @@ -102,5 +102,14 @@ class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest): self.pass_name = 'conv_hard_swish_mkldnn_fuse_pass' +class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest): + def set_params(self): + self.conv_num_filters = 5 + self.conv_filter_size = 5 + self.conv_bias_attr = True + self.act = "hard_sigmoid" + self.pass_name = 'conv_hard_sigmoid_mkldnn_fuse_pass' + + if __name__ == "__main__": unittest.main() -- GitLab