diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 2226169e65b03ce3a0d37c026f38f8031828c0ac..8bc9072948f0226f61c24a18b7e6ce6bc801ff85 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -60,6 +60,10 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
     if (activation_type() == "relu6") {
       desc->SetAttr("fuse_alpha",
                     boost::get<float>(activation->Op()->GetAttr("threshold")));
+    } else if (activation_type() == "swish") {
+      // paddle uses beta but mkldnn uses alpha for swish
+      desc->SetAttr("fuse_alpha",
+                    activation->Op()->GetAttrIfExists<float>("beta"));
     } else {
       desc->SetAttr("fuse_alpha",
                     activation->Op()->GetAttrIfExists<float>("alpha"));
@@ -95,3 +99,6 @@ REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
 
 REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
               paddle::framework::ir::Conv2DReLU6FusePass);
+
+REGISTER_PASS(conv_swish_mkldnn_fuse_pass,
+              paddle::framework::ir::Conv2DSwishFusePass);
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
index 7c6dc238a55af2cf54aee587091fdda2c03cc8aa..ac15fc0451285d4d5575dbc08f430625912ac823 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
@@ -50,6 +50,13 @@ class Conv2DReLU6FusePass : public ConvActivationFusePass {
  public:
   std::string activation_type() const { return "relu6"; }
 };
+/*
+ * Fuse Conv and Swish class
+ */
+class Conv2DSwishFusePass : public ConvActivationFusePass {
+ public:
+  std::string activation_type() const { return "swish"; }
+};
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
index ec38788bb4bf59f97c1a7bbbf63d8e389457d7eb..f4155568cf8743eed3c2204b5cf0a4268ee15828 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
@@ -40,6 +40,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
       op->SetAttr("alpha", 0.02f);
     } else if (type == "relu6") {
       op->SetAttr("threshold", 6.0f);
+    } else if (type == "swish") {
+      op->SetAttr("beta", 1.0f);
     }
   }
   op->SetOutput("Out", outputs);
@@ -133,6 +135,7 @@ TEST(ConvActivationFusePass, conv_leaky_relu_fuse_pass) {
   MainTest("leaky_relu");
 }
 TEST(ConvActivationFusePass, conv_relu6_fuse_pass) { MainTest("relu6"); }
+TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); }
 
 }  // namespace ir
 }  // namespace framework
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index ccd5ded4662458cb3367f52b7179e65584094399..e29a3e3ca2ab680ac5147b7b988271b2388d07a6 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -196,6 +196,7 @@ void CpuPassStrategy::EnableMKLDNN() {
              "conv_relu_mkldnn_fuse_pass",        //
              "conv_leaky_relu_mkldnn_fuse_pass",  //
              "conv_relu6_mkldnn_fuse_pass",       //
+             "conv_swish_mkldnn_fuse_pass",       //
              // Disabled due to topology-dependent speed-up
              // "fc_mkldnn_pass"
          })) {
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 71f67466cbb8e95313c17bd98c19dbfb3c147c69..124470f0c3c2de6fb30398e5a9f82989ea60e4e8 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -589,6 +589,13 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "Input of Swish operator");
     AddOutput("Out", "Output of Swish operator");
     AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
     AddComment(R"DOC(
 Swish Activation Operator.
 
diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
index 3b367c9a5bcd48a25b82869affb4ddd6ff699ca4..b68cb325a7b7cf769c56361f92c539bdaf7b5a4d 100644
--- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -73,8 +73,13 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   const auto *x = ctx.Input<Tensor>("X");
   auto *y = ctx.Output<Tensor>("Out");
 
-  const T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
-  const T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
+  T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
+  T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
+
+  // paddle uses beta but mkldnn uses alpha for swish
+  if (algorithm == mkldnn::algorithm::eltwise_swish) {
+    std::swap(alpha, beta);
+  }
 
   PADDLE_ENFORCE(
       x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
@@ -112,8 +117,13 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
   auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
 
-  const T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
-  const T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
+  T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
+  T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
+
+  // paddle uses beta but mkldnn uses alpha for swish
+  if (algorithm == mkldnn::algorithm::eltwise_swish) {
+    std::swap(alpha, beta);
+  }
 
   auto diff_dst_tz = framework::vectorize(diff_y->dims());
 
@@ -162,6 +172,10 @@ template <typename T>
 using ReluMKLDNNFunctor =
     MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;
 
+template <typename T>
+using SwishMKLDNNFunctor =
+    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_swish>;
+
 template <typename T>
 using TanhMKLDNNFunctor =
     MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;
@@ -178,6 +192,10 @@ template <typename T>
 using ReluMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
 
+template <typename T>
+using SwishMKLDNNGradFunctor =
+    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_swish>;
+
 template <typename T>
 using TanhMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;
@@ -204,6 +222,7 @@ namespace ops = paddle::operators;
 #define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)                  \
   __macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);       \
   __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
+  __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);    \
   __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);       \
   __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);       \
   __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index f8ee9b96398a1a174052eede9eca1aaa6ca7ff1c..25b285ccc83716b1673c66004a264dfde5ed9484 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -978,13 +978,15 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
       constexpr float scale = 1.0f;
       post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
                                      fuse_alpha, fuse_beta);
-    }
-
-    if (fuse_activation == "relu6") {
+    } else if (fuse_activation == "relu6") {
       constexpr float scale = 1.0f;
       post_operations.append_eltwise(scale,
                                      mkldnn::algorithm::eltwise_bounded_relu,
                                      fuse_alpha, fuse_beta);
+    } else if (fuse_activation == "swish") {
+      constexpr float scale = 1.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish,
+                                     fuse_alpha, fuse_beta);
     }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
index c988e6275ffd9ec035d1f6023c330bcf1d4307fc..da1a6ee9669cff288cfe3e5117bba8a67be82cf5 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
@@ -16,9 +16,10 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
+from scipy.special import expit
 import paddle.fluid.core as core
 from paddle.fluid.tests.unittests.op_test import OpTest
-from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu
+from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu, TestSwish
 from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
 
 
@@ -111,6 +112,29 @@ class TestMKLDNNAbsDim2(TestAbs):
             ['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
 
 
+class TestMKLDNNSwishDim2(TestSwish):
+    def setUp(self):
+        super(TestMKLDNNSwishDim2, self).setUp()
+
+        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
+        beta = 2.3
+        out = x * expit(beta * x)
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+        self.attrs = {"use_mkldnn": True, "beta": beta}
+
+    def test_check_output(self):
+        # TODO(wangzhongpu): support mkldnn op in dygraph mode
+        self.check_output()
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        # TODO(wangzhongpu): support mkldnn op in dygraph mode
+        self.check_grad(['X'], 'Out')
+
+
 class TestMKLDNNReluDim4(TestRelu):
     def setUp(self):
         super(TestMKLDNNReluDim4, self).setUp()
@@ -228,6 +252,29 @@ class TestMKLDNNAbsDim4(TestAbs):
             ['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
 
 
+class TestMKLDNNSwishDim4(TestSwish):
+    def setUp(self):
+        super(TestMKLDNNSwishDim4, self).setUp()
+
+        x = np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
+        beta = 2.3
+        out = x * expit(beta * x)
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+        self.attrs = {"use_mkldnn": True, "beta": beta}
+
+    def test_check_output(self):
+        # TODO(wangzhongpu): support mkldnn op in dygraph mode
+        self.check_output()
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        # TODO(wangzhongpu): support mkldnn op in dygraph mode
+        self.check_grad(['X'], 'Out')
+
+
 # Check if primitives already exist in backward
 class TestMKLDNNAbsPrimitivesAlreadyExist(unittest.TestCase):
     def setUp(self):
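
For reference, the swish activation fused by this patch computes out = x * sigmoid(beta * x). Paddle exposes the scale as the `beta` attribute, while MKL-DNN's `eltwise_swish` consumes it as its alpha argument, which is why the fuse pass and the kernel above swap the two. Below is a minimal NumPy/SciPy sketch of that reference formula, matching what the new unit tests compute; the helper name is illustrative and not part of the patch.

import numpy as np
from scipy.special import expit  # numerically stable sigmoid


def swish_reference(x, beta=1.0):
    # swish(x) = x * sigmoid(beta * x): the value the fused conv+swish kernel
    # is expected to reproduce for the activation part.
    return x * expit(beta * x)


# Same shape and beta as TestMKLDNNSwishDim4 above.
x = np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
y = swish_reference(x, beta=2.3)
assert y.shape == x.shape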