diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 0d0151fb738dba2870c38339d20b5cfbcfc77626..6370d3380361c0c3b3434bfb9b73f429174f4cfa 100755
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -230,7 +230,15 @@ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
       .IsType<float>()
       .End();
 }
-
+Conv2DMishFusePass::Conv2DMishFusePass() {
+  AddOpCompat(OpCompat("mish"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+}
 Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
   AddOpCompat(OpCompat("hard_sigmoid"))
       .AddInput("X")
@@ -311,6 +319,14 @@ REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass)
             .LE("conv2d", 1)
             .EQ("hard_swish", 0));
 
+REGISTER_PASS(conv_mish_mkldnn_fuse_pass,
+              paddle::framework::ir::Conv2DMishFusePass);
+REGISTER_PASS_CAPABILITY(conv_mish_mkldnn_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .LE("conv2d", 1)
+            .EQ("mish", 1));
+
 REGISTER_PASS(conv_hard_sigmoid_mkldnn_fuse_pass,
               paddle::framework::ir::Conv2DHardSigmoidFusePass);
 REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
index b8279e48386c7de3492013ee326af839c76e5efa..1a3a3232ddee244d85ce50f61b66b625f456edf9 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
@@ -72,6 +72,14 @@ class Conv2DHardSwishFusePass : public ConvActivationFusePass {
   Conv2DHardSwishFusePass();
   std::string activation_type() const { return "hard_swish"; }
 };
+/*
+ * Fuse Conv and Mish class
+ */
+class Conv2DMishFusePass : public ConvActivationFusePass {
+ public:
+  Conv2DMishFusePass();
+  std::string activation_type() const { return "mish"; }
+};
 /*
  * Fuse Conv and HardSigmoid class
  */
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
index a398e3341698923f0896fab1eeed67ff513de592..1fefab805b1d3620e3f8b966ac77d2f9c10b70fa 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
@@ -148,6 +148,7 @@ TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); }
 TEST(ConvActivationFusePass, conv_hard_swish_fuse_pass) {
   MainTest("hard_swish");
 }
+TEST(ConvActivationFusePass, conv_mish_fuse_pass) { MainTest("mish"); }
 TEST(ConvActivationFusePass, conv_hard_sigmoid_fuse_pass) {
   MainTest("hard_sigmoid");
 }
diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
index 093fd5ec538db1791441d1aa213644a72c89516e..7fc8806452b883040fca1e71ba785583429f6cf3 100644
--- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
@@ -25,7 +25,7 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid",
+  std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid", "mish",
                                         "hard_swish"};
 
   for (std::string act_type : act_types)
     FuseFCAct(graph, act_type);
@@ -99,5 +99,6 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
             .LE("fc", 0)
             .LE("gelu", 0)
             .LE("sigmoid", 0)
+            .LE("mish", 1)
             .LE("hard_swish", 0)
             .LE("tanh", 0));
diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
index 81294dd568926f0c4e86c597f3f82f7b8b13cb62..23f4296b98bcabab17c896a7ea0c80f72e358e06 100644
--- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
@@ -27,8 +27,8 @@ namespace ir {
 * \brief Fuse the FC and activation operators into single OneDNN's
 *        FC with post-op.
 *
- * \note Currently only GeLU, hardswish, sigmoid and tanh are supported as an
- *       activation function.
+ * \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported
+ *       as an activation function.
 */
 class FuseFCActOneDNNPass : public FusePassBase {
  public:
diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
index 38f87f4428d8a702663a2627348fe9cbf0318205..59d81cb86474d5e17ba29dd6c72581ff8c13b41d 100644
--- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
@@ -201,6 +201,36 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
   }
 }
 
+TEST(FuseFCActOneDNNPass, FuseWithMish) {
+  auto prog =
+      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
+  test::CreateOp(&prog, "fc",
+                 {
+                     {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
+                 },
+                 {{"Out", "fc_y"}});
+  test::CreateOp(&prog, "mish", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+
+  Graph graph(prog);
+  constexpr int removed_nodes_count = 2;
+
+  EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
+                                     "act_y", removed_nodes_count));
+  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"mish", 0}}));
+
+  for (const auto* node : graph.Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "fc") {
+      const auto* op = node->Op();
+      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
+      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
+      ASSERT_TRUE(op->HasAttr("activation_type"));
+      auto act_type =
+          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
+      EXPECT_EQ(act_type.compare("mish"), 0);
+    }
+  }
+}
+
 TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
   auto prog =
       test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 75b7a4ea155f0da1cc2d39dc75bec013c2c61683..57f90c7cc4a881ef7c9ae3c98342e907dd7c027f 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -252,6 +252,7 @@ void CpuPassStrategy::EnableMKLDNN() {
              "conv_relu6_mkldnn_fuse_pass",          //
              "conv_swish_mkldnn_fuse_pass",          //
              "conv_hard_swish_mkldnn_fuse_pass",     //
+             "conv_mish_mkldnn_fuse_pass",           //
              "conv_hard_sigmoid_mkldnn_fuse_pass",   //
              // TODO(baoachun) fix int8 accuracy
              "conv_gelu_mkldnn_fuse_pass",
diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
index 0cb074beb60d7998104579f3fecaf12f3bb828c7..dec2fa1836081c5160b2db333815ab24921186ae 100644
--- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -237,6 +237,10 @@ template <typename T>
 using HardSwishMKLDNNFunctor =
     MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_hardswish>;
 
+template <typename T>
+using MishMKLDNNFunctor =
+    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;
+
 template <typename T>
 using SigmoidMKLDNNFunctor =
     MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
@@ -274,6 +278,10 @@ template <typename T>
 using HardSwishMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;
 
+template <typename T>
+using MishMKLDNNGradFunctor =
+    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;
+
 template <typename T>
 using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
     T, dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
@@ -341,6 +349,8 @@ REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
                                        SigmoidMKLDNNGradUseOutFunctor);
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
                                        SqrtMKLDNNGradUseOutFunctor);
+REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(mish, MishMKLDNNFunctor,
+                                       MishMKLDNNGradFunctor);
 
 namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 44289015bc7c4ba98b75a5fb1444afce98b585dc..68e2a7c8a91bb232fb479942d307679137b6172a 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -516,6 +516,10 @@ class ConvMKLDNNHandlerT
       post_operations.append_eltwise(activation_scale,
                                      dnnl::algorithm::eltwise_hardswish,
                                      fuse_alpha, fuse_beta);
+    } else if (fuse_activation == "mish") {
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_mish, fuse_alpha,
+                                     fuse_beta);
     } else if (fuse_activation == "hard_sigmoid") {
       post_operations.append_eltwise(activation_scale,
                                      dnnl::algorithm::eltwise_linear,
diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
index f6a6c6940a79d32434d830c3cc3a8f6f2166e25d..153b0be6dad8f0385fee1c2b2fd84cfe1128777a 100644
--- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
@@ -496,6 +496,11 @@ class FCPrimitiveFactory {
       constexpr float beta = 0.0f;
       post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_logistic,
                                      alpha, beta);
+    } else if (ctx.Attr<std::string>("activation_type") == "mish") {
+      constexpr float alpha = 0.0f;
+      constexpr float beta = 0.0f;
+      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_mish,
+                                     alpha, beta);
     } else if (ctx.Attr<std::string>("activation_type") == "hard_swish") {
       constexpr float alpha = 0.0f;
       constexpr float beta = 0.0f;
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py
index cf9b2257553b70bfe7de83ad05c62d2381fac0dd..56cb0748a232b7ec3a164771768d93b944a0096b 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py
@@ -102,6 +102,15 @@ class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest):
         self.pass_name = 'conv_hard_swish_mkldnn_fuse_pass'
 
 
+class ConvActivationMkldnnFusePassTest_6(ConvActivationMkldnnFusePassTest):
+    def set_params(self):
+        self.conv_num_filters = 5
+        self.conv_filter_size = 5
+        self.conv_bias_attr = True
+        self.act = "mish"
+        self.pass_name = 'conv_mish_mkldnn_fuse_pass'
+
+
 class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
     def set_params(self):
         self.conv_num_filters = 5
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
index 5d759e4ae28e86704441461602126c33e92ad842..66bcca51bed1d55ca8e1bb9e77cb768593aa130c 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
@@ -134,5 +134,27 @@ class FCHardSwishOneDnnFusePassTest(InferencePassTest):
         self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
 
 
+class FCMishOneDnnFusePassTest(InferencePassTest):
+    def setUp(self):
+        self.set_params()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(
+                name="data", shape=[-1, 128, 768], dtype="float32")
+            fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
+            mish_out = fluid.layers.mish(fc_out)
+
+        self.feeds = {"data": np.random.random((1, 128, 768)).astype("float32")}
+
+        self.fetch_list = [mish_out]
+        self.enable_mkldnn = True
+
+    def set_params(self):
+        self.pass_name = "fc_act_mkldnn_fuse_pass"
+
+    def test_check_output(self):
+        self.check_output()
+        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_bf16_mkldnn_op.py
index c421a6d117e458689aa019dd8671cf51c7380606..8e0fdf76459bd7adc427bd0b0279945ab3c84ca3 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_bf16_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_bf16_mkldnn_op.py
@@ -148,5 +148,19 @@ class TestMKLDNNReluBF16Op(MKLDNNBF16ActivationOp, TestActivation):
         return dout
 
 
+class TestMKLDNNMishBF16Op(MKLDNNBF16ActivationOp, TestActivation):
+    def config(self):
+        self.op_type = "mish"
+
+    def op_forward(self, x):
+        return x * np.tanh(np.log(1 + np.exp(x)))
+
+    def op_grad(self, dout, x):
+        omega = np.exp(3 * x) + 4 * np.exp(2 * x) + np.exp(x) * (4 * x + 6
+                                                                 ) + 4 * (x + 1)
+        delta = np.exp(2 * x) + 2 * np.exp(x) + 2
+        return dout * ((np.exp(x) * omega) / delta**2)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
index 8af2101346fec258e58680edc84bd0f2871e0d31..e2d50fc853887eeda86af75f6cbc6f3cc7a662cc 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
@@ -315,6 +315,19 @@ class TestMKLDNNHardSwishDim4(TestHardSwish):
         self.dtype = np.float32
 
 
+class TestMKLDNNMish(TestActivation):
+    def setUp(self):
+        self.op_type = "mish"
+        self.dtype = np.float32
+
+        x = np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype(self.dtype)
+        out = x * np.tanh(np.log(1 + np.exp(x)))
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+        self.attrs = {"use_mkldnn": True}
+
+
 class TestMKLDNNSigmoidDim4(TestSigmoid):
     def setUp(self):
         super(TestMKLDNNSigmoidDim4, self).setUp()
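
For reference, the mish forward and backward formulas exercised by the new TestMKLDNNMishBF16Op test above can be sanity-checked on their own. The NumPy sketch below is illustrative only and not part of the patch; it verifies the closed-form derivative used in op_grad against a central finite difference.

import numpy as np


def mish(x):
    # mish(x) = x * tanh(softplus(x)), same formula as op_forward above
    return x * np.tanh(np.log(1 + np.exp(x)))


def mish_grad(x):
    # Closed-form derivative, same omega/delta expression as op_grad above
    omega = (np.exp(3 * x) + 4 * np.exp(2 * x) + np.exp(x) * (4 * x + 6) +
             4 * (x + 1))
    delta = np.exp(2 * x) + 2 * np.exp(x) + 2
    return (np.exp(x) * omega) / delta**2


x = np.random.uniform(0.1, 1, [2, 4, 3, 5])  # same shape as the unit tests
eps = 1e-4
numeric = (mish(x + eps) - mish(x - eps)) / (2 * eps)
assert np.allclose(mish_grad(x), numeric, atol=1e-3)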