diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index 453cfb85554ec53549bcbfd1f9be4566deb47d54..a073cbe9771172f68b6f784cedb6de04dcfde25c 100755 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,110 +14,109 @@ #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h" -#include - #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace framework { -class OpDesc; -} // namespace framework -} // namespace paddle +#include "paddle/fluid/string/pretty_log.h" namespace paddle { namespace framework { namespace ir { -class Graph; +using string::PrettyLogDetail; + +void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const { + std::vector act_types = { + "relu", "mish", "swish", "sqrt", "hard_swish", "sigmoid", "abs", + "gelu", "relu6", "clip", "tanh", "hard_sigmoid", "leaky_relu"}; + + std::vector conv_types = {"conv2d"}; + + for (const auto& conv_type : conv_types) + for (auto& act_type : act_types) { + std::unordered_map attrs_map; + + if (act_type == "swish") + attrs_map.emplace("beta", "fuse_alpha"); + else if (act_type == "relu6") + attrs_map.emplace("threshold", "fuse_alpha"); + else if (act_type == "hard_sigmoid") { + attrs_map.emplace("slope", "fuse_alpha"); + attrs_map.emplace("offset", "fuse_beta"); + } else if (act_type == "clip") { + attrs_map.emplace("min", "fuse_alpha"); + attrs_map.emplace("max", "fuse_beta"); + } else { + attrs_map.emplace("alpha", "fuse_alpha"); + attrs_map.emplace("beta", "fuse_beta"); + } + FuseConvAct(graph, conv_type, act_type, attrs_map); + } +} -void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { +void ConvActivationMkldnnFusePass::FuseConvAct( + Graph* graph, const std::string& conv_type, std::string& act_type, + const std::unordered_map& attrs_map) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - FusePassBase::Init("conv_activation_mkldnn_fuse", graph); + FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph); GraphPatternDetector gpd; auto* conv_input = gpd.mutable_pattern() ->NewNode("conv_activation_mkldnn_fuse/conv_input") ->AsInput() - ->assert_is_op_input(conv_type(), "Input"); - patterns::ConvActivation conv_activation_pattern( - gpd.mutable_pattern(), "conv_activation_mkldnn_fuse"); - conv_activation_pattern(conv_input, conv_type(), activation_type()); + ->assert_is_op_input(conv_type, "Input"); + patterns::ConvActivation conv_act_pattern(gpd.mutable_pattern(), + "conv_activation_mkldnn_fuse"); + conv_act_pattern(conv_input, conv_type, act_type); int found_conv_activation_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse"; + VLOG(4) << "handle " + conv_type + "+" + act_type + " fuse"; if (!IsCompat(subgraph, g)) { LOG(WARNING) << "conv_activation_mkldnn_fuse_pass op compat failed."; return; } - GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, - conv_activation_pattern); // Filter - GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, - conv_activation_pattern); // tmp - GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_activation_pattern); // CONV op - GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, - conv_activation_pattern); // Out - GET_IR_NODE_FROM_SUBGRAPH(activation, activation, - conv_activation_pattern); // Activation op - - // Transform Conv node into ConvActivation node. - OpDesc* desc = conv->Op(); - desc->SetOutput("Output", - std::vector({activation_out->Name()})); - - if (activation_type() == "gelu" && - activation->Op()->HasAttr("approximate")) { - bool approximate = - BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate")); - std::string type = approximate ? "_tanh" : "_erf"; - desc->SetAttr("fuse_activation", "gelu" + type); - } else { - desc->SetAttr("fuse_activation", activation_type()); - } - // MKLDNN ops use alpha and beta as activation parameters but paddle ops are - // not generalized - if (activation_type() == "relu6") { - desc->SetAttr( - "fuse_alpha", - BOOST_GET_CONST(float, activation->Op()->GetAttr("threshold"))); - } else if (activation_type() == "swish") { - // paddle uses beta but mkldnn uses alpha for swish - desc->SetAttr("fuse_alpha", - activation->Op()->GetAttrIfExists("beta")); - } else { - desc->SetAttr("fuse_alpha", - activation->Op()->GetAttrIfExists("alpha")); - } - desc->SetAttr("fuse_beta", - activation->Op()->GetAttrIfExists("beta")); - - if (activation_type() == "hard_sigmoid") { - desc->SetAttr("fuse_alpha", - activation->Op()->GetAttrIfExists("slope")); - desc->SetAttr("fuse_beta", - activation->Op()->GetAttrIfExists("offset")); + GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, conv_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(activation, activation, conv_act_pattern); + + OpDesc* conv_op = conv->Op(); + OpDesc* act_op = activation->Op(); + + for (const auto& attrs : attrs_map) { + if (act_op->HasAttr(attrs.first)) { + conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first)); + } } - GraphSafeRemoveNodes(graph, {activation, conv_out}); + if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) { + act_type = BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate")) + ? "gelu_tanh" + : "gelu_erf"; + conv_op->SetAttr("fuse_alpha", 0.0f); + conv_op->SetAttr("fuse_beta", 0.0f); + } + conv_op->SetAttr("fuse_activation", act_type); + conv_op->SetOutput("Output", {activation_out->Name()}); - PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL, - platform::errors::InvalidArgument( - "Subgraph has to contain conv input node.")); IR_NODE_LINK_TO(conv, activation_out); + GraphSafeRemoveNodes(graph, {activation, conv_out}); found_conv_activation_count++; }; gpd(graph, handler); - AddStatis(found_conv_activation_count); + if (!Has("disable_logs") || !Get("disable_logs")) { + PrettyLogDetail("--- fused %d conv with %s activation", + found_conv_activation_count, act_type); + } } -ConvActivationFusePass::ConvActivationFusePass() { +ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() { AddOpCompat(OpCompat("conv2d")) .AddInput("Input") .IsTensor() @@ -142,8 +141,6 @@ ConvActivationFusePass::ConvActivationFusePass() { .AddAttr("paddings") .IsType>() .End() - // IsStringIn({"EXPLICIT", "SAME", "VALID"}), MobileNetV2 has no this - // attribute .AddAttr("padding_algorithm") .IsOptional() .IsStringIn({"EXPLICIT", "SAME", "VALID"}) @@ -154,7 +151,6 @@ ConvActivationFusePass::ConvActivationFusePass() { .AddAttr("dilations") .IsType>() .End() - // IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute .AddAttr("data_format") .IsOptional() .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) @@ -167,8 +163,7 @@ ConvActivationFusePass::ConvActivationFusePass() { .AddOutput("Out") .IsTensor() .End(); -} -Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() { + AddOpCompat(OpCompat("leaky_relu")) .AddInput("X") .IsTensor() @@ -176,12 +171,10 @@ Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() { .AddOutput("Out") .IsTensor() .End() - // float, default=0.02 .AddAttr("alpha") .IsType() .End(); -} -Conv2DReLU6FusePass::Conv2DReLU6FusePass() { + AddOpCompat(OpCompat("relu6")) .AddInput("X") .IsTensor() @@ -189,12 +182,10 @@ Conv2DReLU6FusePass::Conv2DReLU6FusePass() { .AddOutput("Out") .IsTensor() .End() - // default = 6.0f .AddAttr("threshold") .IsType() .End(); -} -Conv2DSwishFusePass::Conv2DSwishFusePass() { + AddOpCompat(OpCompat("swish")) .AddInput("X") .IsTensor() @@ -205,8 +196,7 @@ Conv2DSwishFusePass::Conv2DSwishFusePass() { .AddAttr("beta") .IsType() .End(); -} -Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() { + AddOpCompat(OpCompat("hard_swish")) .AddInput("X") .IsTensor() @@ -214,23 +204,19 @@ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() { .AddOutput("Out") .IsTensor() .End() - // float, optional, default=6.0 .AddAttr("threshold") .IsOptional() .IsType() .End() - // float, optional, default=6.0 .AddAttr("scale") .IsOptional() .IsType() .End() - // float, optional, default=3.0 .AddAttr("offset") .IsOptional() .IsType() .End(); -} -Conv2DMishFusePass::Conv2DMishFusePass() { + AddOpCompat(OpCompat("mish")) .AddInput("X") .IsTensor() @@ -238,8 +224,7 @@ Conv2DMishFusePass::Conv2DMishFusePass() { .AddOutput("Out") .IsTensor() .End(); -} -Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() { + AddOpCompat(OpCompat("hard_sigmoid")) .AddInput("X") .IsTensor() @@ -247,19 +232,15 @@ Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() { .AddOutput("Out") .IsTensor() .End() - // optional, default=0.2 .AddAttr("slope") .IsOptional() .IsType() .End() - // optional, default=0.5 .AddAttr("offset") .IsOptional() .IsType() .End(); -} -Conv2DGeluFusePass::Conv2DGeluFusePass() { AddOpCompat(OpCompat("gelu")) .AddInput("X") .IsTensor() @@ -270,6 +251,38 @@ Conv2DGeluFusePass::Conv2DGeluFusePass() { .AddAttr("approximate") .IsType() .End(); + + AddOpCompat(OpCompat("tanh")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("sigmoid")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("sqrt")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("abs")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); } } // namespace ir @@ -277,68 +290,22 @@ Conv2DGeluFusePass::Conv2DGeluFusePass() { } // namespace paddle REGISTER_PASS(conv_activation_mkldnn_fuse_pass, - paddle::framework::ir::ConvActivationFusePass); - -REGISTER_PASS(conv_relu_mkldnn_fuse_pass, - paddle::framework::ir::ConvActivationFusePass); -REGISTER_PASS_CAPABILITY(conv_relu_mkldnn_fuse_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination() - .LE("conv2d", 1) - .EQ("relu", 0)); - -REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass, - paddle::framework::ir::Conv2DLeakyReLUFusePass); -REGISTER_PASS_CAPABILITY(conv_leaky_relu_mkldnn_fuse_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination() - .LE("conv2d", 1) - .LE("leaky_relu", 1)); - -REGISTER_PASS(conv_relu6_mkldnn_fuse_pass, - paddle::framework::ir::Conv2DReLU6FusePass); -REGISTER_PASS_CAPABILITY(conv_relu6_mkldnn_fuse_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination() - .LE("conv2d", 1) - .EQ("relu6", 0)); - -REGISTER_PASS(conv_swish_mkldnn_fuse_pass, - paddle::framework::ir::Conv2DSwishFusePass); -REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination() - .LE("conv2d", 1) - .EQ("swish", 0)); - -REGISTER_PASS(conv_hard_swish_mkldnn_fuse_pass, - paddle::framework::ir::Conv2DHardSwishFusePass); -REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination() - .LE("conv2d", 1) - .EQ("hard_swish", 0)); - -REGISTER_PASS(conv_mish_mkldnn_fuse_pass, - paddle::framework::ir::Conv2DMishFusePass); -REGISTER_PASS_CAPABILITY(conv_mish_mkldnn_fuse_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination() - .LE("conv2d", 1) - .EQ("mish", 1)); - -REGISTER_PASS(conv_hard_sigmoid_mkldnn_fuse_pass, - paddle::framework::ir::Conv2DHardSigmoidFusePass); -REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination() - .LE("conv2d", 1) - .EQ("hard_sigmoid", 0)); + paddle::framework::ir::ConvActivationMkldnnFusePass); -REGISTER_PASS(conv_gelu_mkldnn_fuse_pass, - paddle::framework::ir::Conv2DGeluFusePass); -REGISTER_PASS_CAPABILITY(conv_gelu_mkldnn_fuse_pass) +REGISTER_PASS_CAPABILITY(conv_activation_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("conv2d", 1) - .EQ("gelu", 0)); + .EQ("abs", 0) + .LE("clip", 1) + .EQ("gelu", 0) + .EQ("hard_sigmoid", 0) + .LE("hard_swish", 0) + .LE("leaky_relu", 1) + .LE("mish", 1) + .EQ("relu", 0) + .EQ("relu6", 0) + .EQ("sigmoid", 0) + .EQ("sqrt", 0) + .EQ("swish", 0) + .EQ("tanh", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h index 1a3a3232ddee244d85ce50f61b66b625f456edf9..6259e04fcab40f43e9f69af2f90650dbc71b8f6d 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,84 +18,22 @@ #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/pass.h" namespace paddle { namespace framework { namespace ir { -/* - * Fuse Conv and Activation base class. - */ -class Graph; -class ConvActivationFusePass : public FusePassBase { +class ConvActivationMkldnnFusePass : public FusePassBase { public: - ConvActivationFusePass(); - virtual ~ConvActivationFusePass() {} - virtual std::string conv_type() const { return "conv2d"; } - virtual std::string activation_type() const { return "relu"; } + ConvActivationMkldnnFusePass(); + virtual ~ConvActivationMkldnnFusePass() {} protected: - void ApplyImpl(ir::Graph* graph) const override; - const std::string name_scope_{"conv_activation_mkldnn_fuse"}; -}; -/* - * Fuse Conv and LeakyReLU class - */ -class Conv2DLeakyReLUFusePass : public ConvActivationFusePass { - public: - Conv2DLeakyReLUFusePass(); - std::string activation_type() const { return "leaky_relu"; } -}; -/* - * Fuse Conv and BoundedReLU class - */ -class Conv2DReLU6FusePass : public ConvActivationFusePass { - public: - Conv2DReLU6FusePass(); - std::string activation_type() const { return "relu6"; } -}; -/* - * Fuse Conv and Swish class - */ -class Conv2DSwishFusePass : public ConvActivationFusePass { - public: - Conv2DSwishFusePass(); - std::string activation_type() const { return "swish"; } -}; -/* - * Fuse Conv and HardSwish class - */ -class Conv2DHardSwishFusePass : public ConvActivationFusePass { - public: - Conv2DHardSwishFusePass(); - std::string activation_type() const { return "hard_swish"; } -}; -/* - * Fuse Conv and Mish class - */ -class Conv2DMishFusePass : public ConvActivationFusePass { - public: - Conv2DMishFusePass(); - std::string activation_type() const { return "mish"; } -}; -/* - * Fuse Conv and HardSigmoid class - */ -class Conv2DHardSigmoidFusePass : public ConvActivationFusePass { - public: - Conv2DHardSigmoidFusePass(); - std::string activation_type() const { return "hard_sigmoid"; } -}; + void ApplyImpl(Graph *graph) const override; -/* - * Fuse Conv and Gelu class - */ -class Conv2DGeluFusePass : public ConvActivationFusePass { - public: - Conv2DGeluFusePass(); - std::string activation_type() const { return "gelu"; } + void FuseConvAct( + Graph *graph, const std::string &conv_type, std::string &act_type, + const std::unordered_map &attrs_map) const; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc index e3db85471766f11547e68246c33463ea53001959..4e4560c2d52dcd4fef686c4e7cd983f7735514d8 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc @@ -104,8 +104,7 @@ void MainTest(std::string activation) { std::unique_ptr graph(new ir::Graph(prog)); - auto pass = - PassRegistry::Instance().Get("conv_" + activation + "_mkldnn_fuse_pass"); + auto pass = PassRegistry::Instance().Get("conv_activation_mkldnn_fuse_pass"); int original_nodes_num = graph->Nodes().size(); diff --git a/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc index 2a8a248a99faf3037c27212bc6bd13e0955b6c71..d211e39d38d050ae7cbbf85e7add32988f181e19 100644 --- a/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc @@ -27,7 +27,7 @@ using string::PrettyLogDetail; void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const { std::vector act_types = { - "relu", "tanh", "leaky_relu", "swish", "hardswish", "sqrt", + "relu", "tanh", "leaky_relu", "swish", "hard_swish", "sqrt", "abs", "clip", "gelu", "relu6", "sigmoid"}; std::vector elt_types = {"elementwise_add", "elementwise_sub", "elementwise_mul"}; @@ -56,7 +56,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct( const std::unordered_map &attr_map) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - FusePassBase::Init("elementwise_act", graph); + FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph); GraphPatternDetector gpd; auto *elementwise_input = gpd.mutable_pattern() diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 8d30192e9a6a1e43b4a2e3c1a30c00f4cb5a93f9..261d3db92f590dd431afa0071b522d94423fb65e 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -302,15 +302,7 @@ void CpuPassStrategy::EnableMKLDNN() { // "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass", - "conv_relu_mkldnn_fuse_pass", // - "conv_leaky_relu_mkldnn_fuse_pass", // - "conv_relu6_mkldnn_fuse_pass", // - "conv_swish_mkldnn_fuse_pass", // - "conv_hard_swish_mkldnn_fuse_pass", // - "conv_mish_mkldnn_fuse_pass", // - "conv_hard_sigmoid_mkldnn_fuse_pass", // - // TODO(baoachun) fix int8 accuracy - "conv_gelu_mkldnn_fuse_pass", + "conv_activation_mkldnn_fuse_pass", // "scale_matmul_fuse_pass", // "reshape_transpose_matmul_mkldnn_fuse_pass", // "reshape_transpose_matmul_v2_mkldnn_fuse_pass", // @@ -403,14 +395,7 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("conv_transpose_bias_mkldnn_fuse_pass"); passes_.push_back("conv_elementwise_add_mkldnn_fuse_pass"); passes_.push_back("conv_concat_relu_mkldnn_fuse_pass"); - passes_.push_back("conv_relu_mkldnn_fuse_pass"); - passes_.push_back("conv_leaky_relu_mkldnn_fuse_pass"); - passes_.push_back("conv_relu6_mkldnn_fuse_pass"); - passes_.push_back("conv_swish_mkldnn_fuse_pass"); - passes_.push_back("conv_hard_swish_mkldnn_fuse_pass"); - passes_.push_back("conv_mish_mkldnn_fuse_pass"); - passes_.push_back("conv_hard_sigmoid_mkldnn_fuse_pass"); - passes_.push_back("conv_gelu_mkldnn_fuse_pass"); + passes_.push_back("conv_activation_mkldnn_fuse_pass"); passes_.push_back("fc_fuse_pass"); passes_.push_back("repeated_fc_relu_fuse_pass"); passes_.push_back("fc_mkldnn_pass"); diff --git a/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc index 08f26bae37beaebcbc17affc15a508450774b61f..8c7ed7ffa29aa044d06064150096f1d7de679951 100644 --- a/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc @@ -104,12 +104,12 @@ TEST(Analyzer_vit_ocr, fuse_status) { SetConfig(&cfg, true); int num_ops; auto predictor = CreatePaddlePredictor(cfg); - auto fuse_status = GetFuseStatis( + auto fuse_statis = GetFuseStatis( static_cast(predictor.get()), &num_ops); - CHECK_EQ(fuse_status.at("fc_mkldnn_pass"), 33); - CHECK_EQ(fuse_status.at("conv_activation_mkldnn_fuse"), 2); - CHECK_EQ(fuse_status.at("fc_elementwise_add_mkldnn_fuse"), 16); + CHECK_EQ(fuse_statis.at("fc_mkldnn_pass"), 33); + CHECK_EQ(fuse_statis.at("conv2d_gelu_mkldnn_fuse_pass"), 2); + CHECK_EQ(fuse_statis.at("fc_elementwise_add_mkldnn_fuse"), 16); } #endif diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h index 070bf9511a9fec6c32d11030ec23f0ee839d8562..3e44e3096fd1989886fc034160c927ea9be12ede 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h @@ -61,27 +61,10 @@ class EltwiseMKLDNNKernel : public framework::OpKernel { ? ctx.Attr("activation_beta") : 0.0f; - static std::unordered_map algo_map = { - {"relu", dnnl::algorithm::eltwise_relu}, - {"tanh", dnnl::algorithm::eltwise_tanh}, - {"leaky_relu", dnnl::algorithm::eltwise_relu}, - {"swish", dnnl::algorithm::eltwise_swish}, - {"hardswish", dnnl::algorithm::eltwise_hardswish}, - {"sqrt", dnnl::algorithm::eltwise_sqrt}, - {"abs", dnnl::algorithm::eltwise_abs}, - {"clip", dnnl::algorithm::eltwise_clip}, - {"gelu", dnnl::algorithm::eltwise_gelu_erf}, - {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh}, - {"relu6", dnnl::algorithm::eltwise_bounded_relu}, - {"sigmoid", dnnl::algorithm::eltwise_logistic}}; - - const auto& activation_type = - algo_map.find(ctx.Attr("activation_type")); - - if (activation_type != algo_map.end()) { - post_operations.append_eltwise(scale, activation_type->second, alpha, - beta); - } + const auto activation_algorithm = platform::AcquireActivationAlgorithm( + ctx.Attr("activation_type")); + + post_operations.append_eltwise(scale, activation_algorithm, alpha, beta); } return post_operations; } diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 7b790a6081ed73a038238db7c66b06638fde4075..a2828b978ed44fb6ea0f334f5fa4c583ccc4babe 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -505,41 +505,20 @@ class ConvMKLDNNHandlerT if (fuse_residual_conn) { post_operations.append_sum(sum_scale); } - // Fusion with ReLU layer is executed through the PostOps feature. Create a - // PostOps object and configure it to execute an eltwise relu operation. - if (fuse_activation == "relu" || fuse_activation == "leaky_relu") { - post_operations.append_eltwise(activation_scale, - dnnl::algorithm::eltwise_relu, fuse_alpha, - fuse_beta); - } else if (fuse_activation == "relu6") { - post_operations.append_eltwise(activation_scale, - dnnl::algorithm::eltwise_bounded_relu, - fuse_alpha, fuse_beta); - } else if (fuse_activation == "swish") { - post_operations.append_eltwise(activation_scale, - dnnl::algorithm::eltwise_swish, fuse_alpha, - fuse_beta); - } else if (fuse_activation == "hard_swish") { - post_operations.append_eltwise(activation_scale, - dnnl::algorithm::eltwise_hardswish, - fuse_alpha, fuse_beta); - } else if (fuse_activation == "mish") { - post_operations.append_eltwise(activation_scale, - dnnl::algorithm::eltwise_mish, fuse_alpha, - fuse_beta); - } else if (fuse_activation == "hard_sigmoid") { + + if (fuse_activation == "hard_sigmoid") { post_operations.append_eltwise(activation_scale, dnnl::algorithm::eltwise_linear, fuse_alpha, fuse_beta); post_operations.append_eltwise(activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f); - } else if (fuse_activation == "gelu_tanh") { - post_operations.append_eltwise( - activation_scale, dnnl::algorithm::eltwise_gelu_tanh, 0.0f, 0.0f); - } else if (fuse_activation == "gelu_erf") { - post_operations.append_eltwise( - activation_scale, dnnl::algorithm::eltwise_gelu_erf, 0.0f, 0.0f); + } else if (fuse_activation != "") { + const auto activation_algorithm = + platform::AcquireActivationAlgorithm(fuse_activation); + post_operations.append_eltwise(activation_scale, activation_algorithm, + fuse_alpha, fuse_beta); } + conv_attr.set_post_ops(post_operations); return conv_attr; } diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 382f96e83bfce5e53a4e7bf69c74f1655d2a9dda..3e0fabaac25c8de4df01f426146f4e160fe4c977 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -951,6 +951,33 @@ class ActivationMKLDNNHandler } }; +static const dnnl::algorithm AcquireActivationAlgorithm( + std::string activation_name) { + std::unordered_map activation_map = { + {"abs", dnnl::algorithm::eltwise_abs}, + {"clip", dnnl::algorithm::eltwise_clip}, + {"gelu", dnnl::algorithm::eltwise_gelu_erf}, + {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf}, + {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh}, + {"hard_swish", dnnl::algorithm::eltwise_hardswish}, + {"leaky_relu", dnnl::algorithm::eltwise_relu}, + {"mish", dnnl::algorithm::eltwise_mish}, + {"relu", dnnl::algorithm::eltwise_relu}, + {"relu6", dnnl::algorithm::eltwise_bounded_relu}, + {"sigmoid", dnnl::algorithm::eltwise_logistic}, + {"sqrt", dnnl::algorithm::eltwise_sqrt}, + {"swish", dnnl::algorithm::eltwise_swish}, + {"tanh", dnnl::algorithm::eltwise_tanh}}; + + const auto& activation_type = activation_map.find(activation_name); + + PADDLE_ENFORCE_NE(activation_type, activation_map.end(), + platform::errors::InvalidArgument( + "Activation '%s' not found in oneDNN algorithms mapper", + activation_name)); + return activation_type->second; +} + class ReorderMKLDNNHandler { public: ReorderMKLDNNHandler(std::vector& dims, // NOLINT diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 76feab207eebdb78cf240c1d793aa7232ab08ff5..49dcda0cca14c008297a7629c191821db2b8ab41 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -432,14 +432,7 @@ class Quant2Int8MkldnnPass(object): graph = self._apply_pass(graph, 'conv_transpose_bias_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_concat_relu_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_leaky_relu_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_swish_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_hard_swish_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_mish_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_hard_sigmoid_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_gelu_mkldnn_fuse_pass') + graph = self._apply_pass(graph, 'conv_activation_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'fc_fuse_pass', ['use_gpu', 'use_fc_padding'], [False, False]) graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass') diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py index 1516d1dafd32f52276e4c823dbab07c696d58785..2a3e349a18a7d2bd3ae4b93353225cd0dbbf6e46 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py @@ -170,28 +170,24 @@ class TestConvActMkldnnFusePass(PassAutoScanTest): # 11. Generate legal attr of act act_op = None - self.passes = None + self.passes = ["conv_activation_mkldnn_fuse_pass"] if act_type == "relu6": - self.passes = ["conv_relu6_mkldnn_fuse_pass"] threshold = draw(st.floats(min_value=1.0, max_value=10.0)) act_op = OpConfig("relu6", inputs={"X": ["conv2d_out"]}, outputs={"Out": ["relu_out"]}, threshold=threshold) - if act_type == "leaky_relu": - self.passes = ["conv_leaky_relu_mkldnn_fuse_pass"] + elif act_type == "leaky_relu": alpha = draw(st.floats(min_value=0.1, max_value=1.0)) act_op = OpConfig("leaky_relu", inputs={"X": ["conv2d_out"]}, outputs={"Out": ["relu_out"]}, alpha=alpha) - if act_type == "relu": - self.passes = ["conv_relu_mkldnn_fuse_pass"] + elif act_type == "relu": act_op = OpConfig("relu", inputs={"X": ["conv2d_out"]}, outputs={"Out": ["relu_out"]}) - if act_type == "swish": - self.passes = ["conv_swish_mkldnn_fuse_pass"] + elif act_type == "swish": beta = draw(st.floats(min_value=0.1, max_value=1.0)) act_op = OpConfig("swish", inputs={"X": ["conv2d_out"]}, diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py index 645ca2202648dd884247e813e1b2eb9b712b55fd..0c4bfc8ee26043470cf2f07596419b5afebb849e 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py @@ -18,8 +18,6 @@ import unittest import numpy as np from inference_pass_test import InferencePassTest import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.core import AnalysisConfig from paddle.fluid.core import PassVersionChecker @@ -42,13 +40,13 @@ class ConvActivationMkldnnFusePassTest(InferencePassTest): } self.fetch_list = [conv_out] self.enable_mkldnn = True + self.pass_name = 'conv_activation_mkldnn_fuse_pass' def set_params(self): self.conv_num_filters = 3 self.conv_filter_size = 3 self.conv_bias_attr = False self.act = "relu" - self.pass_name = 'conv_relu_mkldnn_fuse_pass' def test_check_output(self): use_gpu = False @@ -65,7 +63,6 @@ class ConvActivationMkldnnFusePassTest_1(ConvActivationMkldnnFusePassTest): self.conv_filter_size = 5 self.conv_bias_attr = True self.act = "relu" - self.pass_name = 'conv_relu_mkldnn_fuse_pass' class ConvActivationMkldnnFusePassTest_2(ConvActivationMkldnnFusePassTest): @@ -75,7 +72,6 @@ class ConvActivationMkldnnFusePassTest_2(ConvActivationMkldnnFusePassTest): self.conv_filter_size = 3 self.conv_bias_attr = False self.act = "leaky_relu" - self.pass_name = 'conv_leaky_relu_mkldnn_fuse_pass' class ConvActivationMkldnnFusePassTest_3(ConvActivationMkldnnFusePassTest): @@ -85,7 +81,6 @@ class ConvActivationMkldnnFusePassTest_3(ConvActivationMkldnnFusePassTest): self.conv_filter_size = 5 self.conv_bias_attr = True self.act = "leaky_relu" - self.pass_name = 'conv_leaky_relu_mkldnn_fuse_pass' class ConvActivationMkldnnFusePassTest_4(ConvActivationMkldnnFusePassTest): @@ -95,7 +90,6 @@ class ConvActivationMkldnnFusePassTest_4(ConvActivationMkldnnFusePassTest): self.conv_filter_size = 3 self.conv_bias_attr = False self.act = "relu6" - self.pass_name = 'conv_relu6_mkldnn_fuse_pass' class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest): @@ -105,7 +99,6 @@ class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest): self.conv_filter_size = 5 self.conv_bias_attr = True self.act = "hard_swish" - self.pass_name = 'conv_hard_swish_mkldnn_fuse_pass' class ConvActivationMkldnnFusePassTest_6(ConvActivationMkldnnFusePassTest): @@ -115,7 +108,6 @@ class ConvActivationMkldnnFusePassTest_6(ConvActivationMkldnnFusePassTest): self.conv_filter_size = 5 self.conv_bias_attr = True self.act = "mish" - self.pass_name = 'conv_mish_mkldnn_fuse_pass' class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest): @@ -125,7 +117,6 @@ class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest): self.conv_filter_size = 5 self.conv_bias_attr = True self.act = "hard_sigmoid" - self.pass_name = 'conv_hard_sigmoid_mkldnn_fuse_pass' if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py index 65634972117e2ee3aecb71f9dff266e24bf4c8b7..dff23e96dd244d28fce21e43560d21ad93126dc8 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py @@ -102,7 +102,8 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest): yield config, ["conv2d"], (1e-5, 1e-5) def test(self): - self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"]) + self.run_and_statis(quant=False, + passes=["conv_activation_mkldnn_fuse_pass"]) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py index d62770bf758d2741f62b9449f0e797a40aebe927..1eb325f1d0879975428c4ae6dd0d50e806c609a0 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py @@ -104,7 +104,7 @@ class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest): def test(self): self.run_and_statis(quant=False, - passes=["conv_hard_sigmoid_mkldnn_fuse_pass"]) + passes=["conv_activation_mkldnn_fuse_pass"]) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py index ad54ca3d91e21916d4be519cf2eea45edc855628..a0ed73964f3770a413300188eb2ea20fa386003b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py @@ -106,7 +106,7 @@ class TestConvHardSwishMkldnnFusePass(PassAutoScanTest): def test(self): self.run_and_statis(quant=False, - passes=["conv_hard_swish_mkldnn_fuse_pass"]) + passes=["conv_activation_mkldnn_fuse_pass"]) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_mish_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_mish_fuse_pass.py index 365ba5346e392c732eb6182765ef80a094bb0199..6f5810cb802d9f677a547eee9f8a3506d17050a6 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_mish_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_mish_fuse_pass.py @@ -99,7 +99,8 @@ class TestConvMishMkldnnFusePass(PassAutoScanTest): yield config, ["conv2d"], (1e-5, 1e-5) def test(self): - self.run_and_statis(quant=False, passes=["conv_mish_mkldnn_fuse_pass"]) + self.run_and_statis(quant=False, + passes=["conv_activation_mkldnn_fuse_pass"]) if __name__ == "__main__":