diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index f0949cb9dfbd2547826c6c910717808a44809bb7..e811475dd83e9c8145bef60a39570b8dfe358de1 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -931,65 +931,22 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
   return bn_out_var;
 }
 
-PDNode *patterns::ConvActivation::operator()(
-    paddle::framework::ir::PDNode *conv_input,
-    std::string conv_type,
-    std::string activation_type) {
-  // Create Operators
-  conv_input->assert_is_op_input(conv_type, "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
-  auto *activation_op =
-      pattern->NewNode(activation_repr())->assert_is_op(activation_type);
-  // Create variables
-  // Filter
-  auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
-                              ->AsInput()
-                              ->assert_is_persistable_var()
-                              ->assert_is_op_input(conv_type, "Filter");
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *conv_out_var = pattern->NewNode(conv_out_repr())
-                           ->AsIntermediate()
-                           ->assert_is_only_output_of_op(conv_type)
-                           ->assert_is_op_input(activation_type);
-  // output
-  auto *activation_out_var = pattern->NewNode(activation_out_repr())
-                                 ->AsOutput()
-                                 ->assert_is_op_output(activation_type);
-
-  conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
-  activation_op->LinksFrom({conv_out_var}).LinksTo({activation_out_var});
-  return activation_out_var;
-}
-
-PDNode *patterns::ElementwiseActivation::operator()(
-    paddle::framework::ir::PDNode *elementwise_a,
-    const std::string &elementwise_type,
-    const std::string &activation_type) {
-  // Create Operators
-  elementwise_a->assert_is_op_input(elementwise_type, "X");
-  auto *elementwise_op =
-      pattern->NewNode(elementwise_repr())->assert_is_op(elementwise_type);
+PDNode *patterns::OperatorActivation::operator()(
+    const std::string &operator_type, const std::string &activation_type) {
+  auto *preceding_op =
+      pattern->NewNode(preceding_op_repr())->assert_is_op(operator_type);
+  auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
+                               ->AsIntermediate()
+                               ->assert_is_only_output_of_op(operator_type)
+                               ->assert_is_op_input(activation_type);
   auto *activation_op =
       pattern->NewNode(activation_repr())->assert_is_op(activation_type);
-  // Create variables
-  auto *elementwise_b = pattern->NewNode(elementwise_b_repr())
-                            ->AsInput()
-                            ->assert_is_op_input(elementwise_type, "Y");
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *elementwise_out_var =
-      pattern->NewNode(elementwise_out_repr())
-          ->AsIntermediate()
-          ->assert_is_only_output_of_op(elementwise_type)
-          ->assert_is_op_input(activation_type);
-  // output
-  auto *activation_out_var = pattern->NewNode(activation_out_repr())
-                                 ->AsOutput()
-                                 ->assert_is_op_output(activation_type);
-
-  elementwise_op->LinksFrom({elementwise_a, elementwise_b})
-      .LinksTo({elementwise_out_var});
-  activation_op->LinksFrom({elementwise_out_var}).LinksTo({activation_out_var});
-  return activation_out_var;
+  auto *activation_out = pattern->NewNode(activation_out_repr())
+                             ->AsOutput()
+                             ->assert_is_op_output(activation_type);
+  preceding_op->LinksTo({preceding_op_out});
+  activation_op->LinksFrom({preceding_op_out}).LinksTo({activation_out});
+  return activation_out;
 }
 
 PDNode *patterns::SeqConvEltAddRelu::operator()(
@@ -1121,44 +1078,6 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
   return fc_out_var;
 }
 
-PDNode *patterns::FCActOneDNN::operator()(const std::string &act_type) {
-  auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc");
-  auto *fc_out = pattern->NewNode(fc_out_repr())
-                     ->assert_is_op_output("fc", "Out")
-                     ->assert_is_op_input(act_type);
-  auto *act =
-      pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
-  auto *act_out = pattern->NewNode(act_out_repr())
-                      ->assert_is_op_output(act_type, "Out")
-                      ->AsOutput();
-
-  fc->LinksTo({fc_out});
-  act->LinksFrom({fc_out}).LinksTo({act_out});
-
-  return act_out;
-}
-
-PDNode *patterns::SoftplusActivation::operator()(std::string activation_type) {
-  // Create Operators
-  auto *softplus_op =
-      pattern->NewNode(softplus_repr())->assert_is_op("softplus");
-  auto *activation_op =
-      pattern->NewNode(activation_repr())->assert_is_op(activation_type);
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *softplus_out = pattern->NewNode(softplus_out_repr())
-                           ->AsIntermediate()
-                           ->assert_is_only_output_of_op("softplus")
-                           ->assert_is_op_input(activation_type);
-  // output
-  auto *activation_out = pattern->NewNode(activation_out_repr())
-                             ->AsOutput()
-                             ->assert_is_op_output(activation_type);
-
-  softplus_op->LinksTo({softplus_out});
-  activation_op->LinksFrom({softplus_out}).LinksTo({activation_out});
-  return activation_out;
-}
-
 PDNode *patterns::Embedding::operator()(PDNode *x) {
   x->assert_is_op_input("lookup_table", "Ids");
   auto *lookup_table_op =
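Editor's note (illustrative sketch, not part of this patch): every fuse pass consumes the unified pattern above in the same way. The pass name and the handler body below are placeholders; only the pattern and node names come from the code above.

```cpp
// Hypothetical pass fragment using patterns::OperatorActivation.
GraphPatternDetector gpd;
patterns::OperatorActivation op_act_pattern(gpd.mutable_pattern(),
                                            "some_op_activation_fuse");  // placeholder name scope
op_act_pattern("conv2d", "relu");  // preceding operator type + activation type

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
  // The four nodes the pattern declares:
  GET_IR_NODE_FROM_SUBGRAPH(preceding_op, preceding_op, op_act_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(preceding_op_out, preceding_op_out, op_act_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(activation, activation, op_act_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, op_act_pattern);
  // ... rewrite preceding_op in place, then drop {activation, preceding_op_out} ...
};
gpd(graph, handler);
```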
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index be14ef2dbf3ea69dcb05edf4466c4429c3012722..9210cecabe7c682a7b96b072b014e1ba041ca526 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -524,49 +524,16 @@ struct ConvBN : public PatternBase {
   PATTERN_DECL_NODE(bn_saved_variance);
 };
 
-// Conv with Activation
-// op: conv + activation
-// named nodes:
-// conv_input, conv_weight,
-// conv_out, conv,
-// activation_out, activation
-struct ConvActivation : public PatternBase {
-  ConvActivation(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "conv_activation") {}
-
-  PDNode* operator()(PDNode* conv_input,
-                     std::string conv_type = "conv2d",
-                     std::string activation_type = "relu");
+struct OperatorActivation : public PatternBase {
+  OperatorActivation(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "operator_activation") {}
 
-  // declare operator node's name
-  PATTERN_DECL_NODE(conv);
-  PATTERN_DECL_NODE(activation);
-  // declare variable node's name
-  PATTERN_DECL_NODE(conv_weight);
-  PATTERN_DECL_NODE(conv_out);
-  PATTERN_DECL_NODE(activation_out);
-};
-
-// Elementwise with Activation
-// op: elementwise + activation
-// named nodes:
-// elementwise_a, elementwise_b,
-// elementwise_out, elementwise,
-// activation_out, activation
-struct ElementwiseActivation : public PatternBase {
-  ElementwiseActivation(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "elementwise_add_activation") {}
-
-  PDNode* operator()(PDNode* elementwise_a,
-                     const std::string& elementwise_type,
+  PDNode* operator()(const std::string& operator_type,
                      const std::string& activation_type);
 
-  // declare operator node's name
-  PATTERN_DECL_NODE(elementwise);
+  PATTERN_DECL_NODE(preceding_op);
+  PATTERN_DECL_NODE(preceding_op_out);
   PATTERN_DECL_NODE(activation);
-  // declare variable node's name
-  PATTERN_DECL_NODE(elementwise_b);
-  PATTERN_DECL_NODE(elementwise_out);
   PATTERN_DECL_NODE(activation_out);
 };
 
@@ -639,45 +606,6 @@ struct FCMKLDNN : public PatternBase {
   PATTERN_DECL_NODE(output);
 };
 
-//
-// \brief   Pattern looking for fc and a directly following activation
-// operator.
-//
-// \note    Currently only gelu and tanh are supported as an activation
-//          function.
-// Formula: act(fc(x))
-// Op: fc + act
-struct FCActOneDNN : public PatternBase {
-  FCActOneDNN(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "fc_act_onednn") {}
-
-  PDNode* operator()(const std::string& act_type);
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(fc);
-  PATTERN_DECL_NODE(act);
-  PATTERN_DECL_NODE(fc_out);
-  PATTERN_DECL_NODE(act_out);
-};
-
-// Fuse softplus with activation
-// ops: softplus + activation
-// nodes:
-//    softplus, softplus_out,
-//    activation, activation_out
-struct SoftplusActivation : public PatternBase {
-  SoftplusActivation(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "softplus_activation") {}
-
-  PDNode* operator()(std::string activation_type);
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(softplus);
-  PATTERN_DECL_NODE(activation);
-  PATTERN_DECL_NODE(softplus_out);
-  PATTERN_DECL_NODE(activation_out);
-};
-
 // Embedding
 struct Embedding : public PatternBase {
   Embedding(PDPattern* pattern, const std::string& name_scope)
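Editor's note (sketch, not from the patch): the four removed pattern structs collapse onto plain instantiations of OperatorActivation; `pattern` below stands for a PDPattern* and the name-scope strings are the ones used by the passes later in this diff. Unlike the old ConvActivation, the unified pattern declares no conv_input/conv_weight nodes, so the passes now anchor only on the operator and its sole output.

```cpp
// Illustrative mapping of the removed patterns onto the unified one.
patterns::OperatorActivation conv_act(pattern, "conv_activation_mkldnn_fuse");
conv_act("conv2d", "relu");             // was patterns::ConvActivation

patterns::OperatorActivation elt_act(pattern, "elementwise_add_act");
elt_act("elementwise_add", "gelu");     // was patterns::ElementwiseActivation

patterns::OperatorActivation fc_act(pattern, "fc_act");
fc_act("fc", "tanh");                   // was patterns::FCActOneDNN

patterns::OperatorActivation softplus_act(pattern, "softplus_activation");
softplus_act("softplus", "relu");       // was patterns::SoftplusActivation
```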
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 4eefc2987bcb44c7ee40e8f9ff8d191a7eac1a71..8c140e8132489a01f8527675f9af58f235abc294 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
 
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
@@ -24,61 +25,27 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
-  std::vector<std::string> act_types = {"relu",
-                                        "mish",
-                                        "swish",
-                                        "sqrt",
-                                        "hard_swish",
-                                        "sigmoid",
-                                        "abs",
-                                        "gelu",
-                                        "relu6",
-                                        "clip",
-                                        "tanh",
-                                        "hard_sigmoid",
-                                        "leaky_relu"};
+  auto act_types = paddle::platform::GetSupportedActivations();
   std::vector<std::string> conv_types = {"conv2d"};
 
   for (const auto& conv_type : conv_types)
     for (auto& act_type : act_types) {
-      std::unordered_map<std::string, std::string> attrs_map;
-
-      if (act_type == "swish")
-        attrs_map.emplace("beta", "fuse_alpha");
-      else if (act_type == "relu6")
-        attrs_map.emplace("threshold", "fuse_alpha");
-      else if (act_type == "hard_sigmoid") {
-        attrs_map.emplace("slope", "fuse_alpha");
-        attrs_map.emplace("offset", "fuse_beta");
-      } else if (act_type == "clip") {
-        attrs_map.emplace("min", "fuse_alpha");
-        attrs_map.emplace("max", "fuse_beta");
-      } else {
-        attrs_map.emplace("alpha", "fuse_alpha");
-        attrs_map.emplace("beta", "fuse_beta");
-      }
-      FuseConvAct(graph, conv_type, act_type, attrs_map);
+      FuseConvAct(graph, conv_type, act_type);
     }
 }
 
-void ConvActivationMkldnnFusePass::FuseConvAct(
-    Graph* graph,
-    const std::string& conv_type,
-    std::string& act_type,
-    const std::unordered_map<std::string, std::string>& attrs_map) const {
+void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
+                                               const std::string& conv_type,
+                                               std::string& act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
 
   GraphPatternDetector gpd;
-  auto* conv_input = gpd.mutable_pattern()
-                         ->NewNode("conv_activation_mkldnn_fuse/conv_input")
-                         ->AsInput()
-                         ->assert_is_op_input(conv_type, "Input");
-  patterns::ConvActivation conv_act_pattern(gpd.mutable_pattern(),
-                                            "conv_activation_mkldnn_fuse");
-  conv_act_pattern(conv_input, conv_type, act_type);
+  patterns::OperatorActivation conv_act_pattern(gpd.mutable_pattern(),
+                                                "conv_activation_mkldnn_fuse");
+  conv_act_pattern(conv_type, act_type);
 
   int found_conv_activation_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -90,16 +57,16 @@ void ConvActivationMkldnnFusePass::FuseConvAct(
       return;
     }
 
-    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, conv_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv, preceding_op, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_out, preceding_op_out, conv_act_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(activation, activation, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
 
     OpDesc* conv_op = conv->Op();
     OpDesc* act_op = activation->Op();
 
-    for (const auto& attrs : attrs_map) {
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
+    for (const auto& attrs : attr_map) {
      if (act_op->HasAttr(attrs.first)) {
        conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
      }
    }
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
index e1e2898384609ec02d3a45d32ac55e6174974d3d..11925e1992df4c823307e520c8670c7a1fdc038f 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
@@ -31,11 +31,9 @@ class ConvActivationMkldnnFusePass : public FusePassBase {
  protected:
   void ApplyImpl(Graph *graph) const override;
 
-  void FuseConvAct(
-      Graph *graph,
-      const std::string &conv_type,
-      std::string &act_type,
-      const std::unordered_map<std::string, std::string> &attrs_map) const;
+  void FuseConvAct(Graph *graph,
+                   const std::string &conv_type,
+                   std::string &act_type) const;
 };
 
 }  // namespace ir
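Editor's note (for quick reference, not part of the patch): every pass now maps activation attributes onto the same fuse_alpha/fuse_beta names through the shared helper defined in mkldnn_reuse.h further down in this diff. A few representative results:

```cpp
// Illustrative only: what GetAttributeMap returns for several activations.
auto m1 = paddle::platform::GetAttributeMap("swish");         // {"beta"      -> "fuse_alpha"}
auto m2 = paddle::platform::GetAttributeMap("relu6");         // {"threshold" -> "fuse_alpha"}
auto m3 = paddle::platform::GetAttributeMap("hard_sigmoid");  // {"slope" -> "fuse_alpha", "offset" -> "fuse_beta"}
auto m4 = paddle::platform::GetAttributeMap("clip");          // {"min"   -> "fuse_alpha", "max"    -> "fuse_beta"}
auto m5 = paddle::platform::GetAttributeMap("leaky_relu");    // {"alpha" -> "fuse_alpha", "beta"   -> "fuse_beta"}
```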
"activation_alpha"); - attr_map.emplace("max", "activation_beta"); - } else { - attr_map.emplace("alpha", "activation_alpha"); - attr_map.emplace("beta", "activation_beta"); - } - FuseElementwiseAct(graph, elt_type, act_type, attr_map); + FuseElementwiseAct(graph, elt_type, act_type); } } void ElementwiseActivationOneDNNPass::FuseElementwiseAct( Graph *graph, const std::string &elt_type, - const std::string &act_type, - const std::unordered_map &attr_map) const { + const std::string &act_type) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph); GraphPatternDetector gpd; - auto *elementwise_input = gpd.mutable_pattern() - ->NewNode(elt_type + "_act/elementwise_input") - ->AsInput() - ->assert_is_op_input(elt_type, "X"); - patterns::ElementwiseActivation elementwise_act_pattern(gpd.mutable_pattern(), - elt_type + "_act"); - elementwise_act_pattern(elementwise_input, elt_type, act_type); + patterns::OperatorActivation elementwise_act_pattern(gpd.mutable_pattern(), + elt_type + "_act"); + elementwise_act_pattern(elt_type, act_type); int found_elementwise_activation_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { VLOG(4) << "Fuse " << elt_type << " with activation op."; - // Elementwise output - GET_IR_NODE_FROM_SUBGRAPH( - elementwise_out, elementwise_out, elementwise_act_pattern); - // ACT output GET_IR_NODE_FROM_SUBGRAPH( - activation_out, activation_out, elementwise_act_pattern); - // ops + elementwise, preceding_op, elementwise_act_pattern); GET_IR_NODE_FROM_SUBGRAPH( - elementwise, elementwise, elementwise_act_pattern); + elementwise_out, preceding_op_out, elementwise_act_pattern); GET_IR_NODE_FROM_SUBGRAPH(activation, activation, elementwise_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + activation_out, activation_out, elementwise_act_pattern); auto *elementwise_op = elementwise->Op(); @@ -106,6 +76,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct( } auto *activation_op = activation->Op(); + auto attr_map = paddle::platform::GetAttributeMap(act_type); for (const auto &attr : attr_map) { if (activation_op->HasAttr(attr.first)) { elementwise_op->SetAttr(attr.second, @@ -115,9 +86,9 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct( if (act_type == "gelu" && activation_op->HasAttr("approximate") && BOOST_GET_CONST(bool, activation_op->GetAttr("approximate"))) - elementwise_op->SetAttr("activation_type", std::string("gelu_tanh")); + elementwise_op->SetAttr("fuse_activation", std::string("gelu_tanh")); else - elementwise_op->SetAttr("activation_type", act_type); + elementwise_op->SetAttr("fuse_activation", act_type); elementwise_op->SetOutput("Out", {activation_out->Name()}); @@ -146,14 +117,16 @@ REGISTER_PASS_CAPABILITY(elt_act_mkldnn_fuse_pass) .LE("elementwise_add", 1) .LE("elementwise_sub", 1) .LE("elementwise_mul", 1) - .LE("relu", 0) - .LE("tanh", 0) - .LE("leaky_relu", 1) - .LE("swish", 0) - .LE("hard_swish", 0) - .LE("sqrt", 0) - .LE("abs", 0) + .EQ("abs", 0) .LE("clip", 1) - .LE("gelu", 0) - .LE("relu6", 0) - .LE("sigmoid", 0)); + .EQ("gelu", 0) + .EQ("hard_sigmoid", 0) + .LE("hard_swish", 0) + .LE("leaky_relu", 1) + .LE("mish", 1) + .EQ("relu", 0) + .EQ("relu6", 0) + .EQ("sigmoid", 0) + .EQ("sqrt", 0) + .EQ("swish", 0) + .EQ("tanh", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h index 
diff --git a/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h
index 8df479e3ddf06d56567d862e6ba37740afde17d8..37bd5345ec78f487d4e7da90d1dfd69e218229c9 100644
--- a/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h
@@ -34,11 +34,9 @@ class ElementwiseActivationOneDNNPass : public FusePassBase {
  protected:
   void ApplyImpl(Graph *graph) const override;
 
-  void FuseElementwiseAct(
-      Graph *graph,
-      const std::string &elt_types,
-      const std::string &act_types,
-      const std::unordered_map<std::string, std::string> &attr_map) const;
+  void FuseElementwiseAct(Graph *graph,
+                          const std::string &elt_types,
+                          const std::string &act_types) const;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
index 99243ec7d7047be783879bcecca5c495e79ca540..e5031c83aac160bbf71772fd3c556d0edc5041e5 100644
--- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
@@ -39,20 +39,17 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
   FusePassBase::Init("fc_act", graph);
 
   GraphPatternDetector gpd;
-  patterns::FCActOneDNN fc_act_pattern(gpd.mutable_pattern(), "fc_act");
-  fc_act_pattern(act_type);
+  patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
+  fc_act_pattern("fc", act_type);
 
   int found_fc_act_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
     VLOG(4) << "Fuse fc with activation op.";
-    // FC output
-    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_act_pattern);
-    // ACT output
-    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fc_act_pattern);
-    // ops
-    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(act, act, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc, preceding_op, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_out, preceding_op_out, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(act, activation, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(act_out, activation_out, fc_act_pattern);
 
     auto *fc_op = fc->Op();
     auto *act_op = act->Op();
diff --git a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
index 3dd850d886c8e02681293810b2d951b0e9c2c969..41e70e529bf73dbbc564b13384193ba6df61de4a 100644
--- a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
@@ -17,6 +17,7 @@
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
@@ -26,59 +27,34 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {"relu",
-                                        "tanh",
-                                        "leaky_relu",
-                                        "swish",
-                                        "hardswish",
-                                        "sqrt",
-                                        "abs",
-                                        "clip",
-                                        "gelu",
-                                        "relu6",
-                                        "sigmoid"};
+  auto act_types = paddle::platform::GetSupportedActivations();
 
   for (const auto &act_type : act_types) {
-    std::unordered_map<std::string, std::string> attr_map;
-
-    if (act_type == "swish")
-      attr_map.emplace("beta", "fuse_activation_alpha");
-    else if (act_type == "relu6")
-      attr_map.emplace("threshold", "fuse_activation_alpha");
-    else if (act_type == "clip") {
-      attr_map.emplace("min", "fuse_activation_alpha");
-      attr_map.emplace("max", "fuse_activation_beta");
-    } else {
attr_map.emplace("alpha", "fuse_activation_alpha"); - attr_map.emplace("beta", "fuse_activation_beta"); - } - FuseSoftplusActivation(graph, act_type, attr_map); + FuseSoftplusActivation(graph, act_type); } } void SoftplusActivationOneDNNPass::FuseSoftplusActivation( - Graph *graph, - const std::string &fuse_activation_type, - const std::unordered_map &attr_map) const { + Graph *graph, const std::string &act_type) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); FusePassBase::Init("softplus_activation", graph); GraphPatternDetector gpd; - patterns::SoftplusActivation softplus_activation_pattern( + patterns::OperatorActivation softplus_activation_pattern( gpd.mutable_pattern(), "softplus_activation"); - softplus_activation_pattern(fuse_activation_type); + softplus_activation_pattern("softplus", act_type); int found_softplus_activation_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { VLOG(4) << "Fuse softplus with activation op."; GET_IR_NODE_FROM_SUBGRAPH( - softplus_out, softplus_out, softplus_activation_pattern); + softplus_out, preceding_op_out, softplus_activation_pattern); GET_IR_NODE_FROM_SUBGRAPH( activation_out, activation_out, softplus_activation_pattern); - - GET_IR_NODE_FROM_SUBGRAPH(softplus, softplus, softplus_activation_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softplus, preceding_op, softplus_activation_pattern); GET_IR_NODE_FROM_SUBGRAPH( activation, activation, softplus_activation_pattern); @@ -94,18 +70,18 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation( } auto *activation_op = activation->Op(); + auto attr_map = paddle::platform::GetAttributeMap(act_type); for (const auto &attr : attr_map) { if (activation_op->HasAttr(attr.first)) { softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first)); } } - if (fuse_activation_type == "gelu" && - activation_op->HasAttr("approximate") && + if (act_type == "gelu" && activation_op->HasAttr("approximate") && BOOST_GET_CONST(bool, activation_op->GetAttr("approximate"))) - softplus_op->SetAttr("fuse_activation_type", std::string("gelu_tanh")); + softplus_op->SetAttr("fuse_activation", std::string("gelu_tanh")); else - softplus_op->SetAttr("fuse_activation_type", fuse_activation_type); + softplus_op->SetAttr("fuse_activation", act_type); softplus_op->SetAttr("use_mkldnn", true); @@ -121,7 +97,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation( if (!Has("disable_logs") || !Get("disable_logs")) PrettyLogDetail("--- fused %d softplus with %s activation", found_softplus_activation_count, - fuse_activation_type); + act_type); } } // namespace ir } // namespace framework @@ -133,13 +109,16 @@ REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("softplus", 1) - .EQ("relu", 0) - .EQ("tanh", 0) - .LE("leaky_relu", 1) - .EQ("swish", 0) - .EQ("hard_swish", 0) - .EQ("sqrt", 0) .EQ("abs", 0) - .LE("relu6", 1) .LE("clip", 1) - .EQ("gelu", 0)); + .EQ("gelu", 0) + .EQ("hard_sigmoid", 0) + .LE("hard_swish", 0) + .LE("leaky_relu", 1) + .LE("mish", 1) + .EQ("relu", 0) + .EQ("relu6", 0) + .EQ("sigmoid", 0) + .EQ("sqrt", 0) + .EQ("swish", 0) + .EQ("tanh", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h index c49502c674355abf26302f91b21af3371d996d07..6368a102b0e852a1f3dc3969287a4ba5ad88462c 100644 --- 
diff --git a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h
index c49502c674355abf26302f91b21af3371d996d07..6368a102b0e852a1f3dc3969287a4ba5ad88462c 100644
--- a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h
@@ -34,10 +34,8 @@ class SoftplusActivationOneDNNPass : public FusePassBase {
  protected:
   void ApplyImpl(ir::Graph *graph) const override;
 
-  void FuseSoftplusActivation(
-      ir::Graph *graph,
-      const std::string &fuse_activation_type,
-      const std::unordered_map<std::string, std::string> &attr_map) const;
+  void FuseSoftplusActivation(ir::Graph *graph,
+                              const std::string &act_type) const;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
index 7f6566460ab6290c623c03567cc21f3cd24b77be..42d749b7b8e3e4cfb2cd0f222321f76ae08c2e3c 100644
--- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
+++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
@@ -50,22 +50,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
  private:
  dnnl::post_ops get_post_ops(const framework::ExecutionContext& ctx) const {
    dnnl::post_ops post_operations;
-    if (ctx.HasAttr("activation_type")) {
-      const float scale = ctx.HasAttr("activation_scale")
-                              ? ctx.Attr<float>("activation_scale")
-                              : 1.0f;
-      const float alpha = ctx.HasAttr("activation_alpha")
-                              ? ctx.Attr<float>("activation_alpha")
-                              : 0.0f;
-      const float beta = ctx.HasAttr("activation_beta")
-                             ? ctx.Attr<float>("activation_beta")
-                             : 0.0f;
-
-      const auto activation_algorithm = platform::AcquireActivationAlgorithm(
-          ctx.Attr<std::string>("activation_type"));
-
-      post_operations.append_eltwise(scale, activation_algorithm, alpha, beta);
-    }
+    platform::AppendActivation(ctx, post_operations);
    return post_operations;
  }
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 17d4c2fad96b8ed32ca8f6558c85a0fcddb6b88d..8ee97c281e3f4d677194c57c1f86b7c50f6a0eb3 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -553,10 +553,6 @@ class ConvMKLDNNHandlerT
     dnnl::primitive_attr conv_attr;
     dnnl::post_ops post_operations;
 
-    const std::string fuse_activation =
-        ctx.Attr<std::string>("fuse_activation");
-    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
-    const float fuse_beta = ctx.Attr<float>("fuse_beta");
     const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
 
     float sum_scale = 1.0f;
@@ -587,19 +583,7 @@ class ConvMKLDNNHandlerT
       post_operations.append_sum(sum_scale);
     }
 
-    if (fuse_activation == "hard_sigmoid") {
-      post_operations.append_eltwise(activation_scale,
-                                     dnnl::algorithm::eltwise_linear,
-                                     fuse_alpha,
-                                     fuse_beta);
-      post_operations.append_eltwise(
-          activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
-    } else if (fuse_activation != "") {
-      const auto activation_algorithm =
-          platform::AcquireActivationAlgorithm(fuse_activation);
-      post_operations.append_eltwise(
-          activation_scale, activation_algorithm, fuse_alpha, fuse_beta);
-    }
+    platform::AppendActivation(ctx, post_operations, activation_scale);
 
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
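Editor's note (sketch under assumptions, not from the patch): the two kernels above now share one call-site shape; `ctx`, `fuse_residual_conn` and `sum_scale` stand in for the surrounding kernel code.

```cpp
// Hypothetical oneDNN kernel fragment using the shared helper.
dnnl::primitive_attr attrs;
dnnl::post_ops post_operations;
if (fuse_residual_conn) {                 // optional sum post-op first
  post_operations.append_sum(sum_scale);
}
// Reads "fuse_activation", "fuse_alpha", "fuse_beta" from the op's attributes
// and appends the matching eltwise post-op (linear + clip for hard_sigmoid);
// does nothing when "fuse_activation" is absent or empty.
platform::AppendActivation(ctx, post_operations /*, activation_scale = 1.0f*/);
attrs.set_post_ops(post_operations);
```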
diff --git a/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h b/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h
index d2aa1cfc6bbf7b8fd38f5d57e2ccbb71cad58adf..c41864ee26f55bb313be2ef1275d61764823e23a 100644
--- a/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h
+++ b/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h
@@ -46,7 +46,7 @@ class SoftplusMKLDNNHandler
           1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f);
     }
 
-    AppendFusedActivationIfExists(ctx, &post_ops);
+    platform::AppendActivation(ctx, post_ops);
 
     dnnl::primitive_attr attrs;
     attrs.set_post_ops(post_ops);
@@ -62,42 +62,8 @@ class SoftplusMKLDNNHandler
     return this->AcquireMemoryFromPrimitive(
         this->fwd_pd_->src1_desc(), platform::to_void_cast<float>(beta));
   }
-
- private:
-  void AppendFusedActivationIfExists(const framework::ExecutionContext& ctx,
-                                     dnnl::post_ops* post_ops) {
-    const auto& fused_activation_type =
-        algo_map.find(ctx.Attr<std::string>("fuse_activation_type"));
-
-    if (fused_activation_type != algo_map.end()) {
-      auto scale_out =
-          ctx.Attr<float>("fuse_activation_scale");  // for future int8 support
-      post_ops->append_eltwise(scale_out,
-                               fused_activation_type->second,
-                               ctx.Attr<float>("fuse_activation_alpha"),
-                               ctx.Attr<float>("fuse_activation_beta"));
-    }
-  }
-
-  static const std::unordered_map<std::string, dnnl::algorithm> algo_map;
 };
 
-template <typename T>
-const std::unordered_map<std::string, dnnl::algorithm>
-    SoftplusMKLDNNHandler<T>::algo_map = {
-        {"relu", dnnl::algorithm::eltwise_relu},
-        {"tanh", dnnl::algorithm::eltwise_tanh},
-        {"leaky_relu", dnnl::algorithm::eltwise_relu},
-        {"swish", dnnl::algorithm::eltwise_swish},
-        {"hardswish", dnnl::algorithm::eltwise_hardswish},
-        {"sqrt", dnnl::algorithm::eltwise_sqrt},
-        {"abs", dnnl::algorithm::eltwise_abs},
-        {"clip", dnnl::algorithm::eltwise_clip},
-        {"gelu", dnnl::algorithm::eltwise_gelu_erf},
-        {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
-        {"relu6", dnnl::algorithm::eltwise_bounded_relu},
-        {"sigmoid", dnnl::algorithm::eltwise_logistic}};
-
 template <typename T>
 void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) {
   const auto& dev_ctx =
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 2f4bbfaf74fcc32e816badc904e2ef1c7e4be63f..f1963a75b17293c442ad55cbb0a7cbae5aa2ff64 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -1013,32 +1013,93 @@ class ActivationMKLDNNHandler
   }
 };
 
-static const dnnl::algorithm AcquireActivationAlgorithm(
-    std::string activation_name) {
-  std::unordered_map<std::string, dnnl::algorithm> activation_map = {
-      {"abs", dnnl::algorithm::eltwise_abs},
-      {"clip", dnnl::algorithm::eltwise_clip},
-      {"gelu", dnnl::algorithm::eltwise_gelu_erf},
-      {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
-      {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
-      {"hard_swish", dnnl::algorithm::eltwise_hardswish},
-      {"leaky_relu", dnnl::algorithm::eltwise_relu},
-      {"mish", dnnl::algorithm::eltwise_mish},
-      {"relu", dnnl::algorithm::eltwise_relu},
-      {"relu6", dnnl::algorithm::eltwise_bounded_relu},
-      {"sigmoid", dnnl::algorithm::eltwise_logistic},
-      {"sqrt", dnnl::algorithm::eltwise_sqrt},
-      {"swish", dnnl::algorithm::eltwise_swish},
-      {"tanh", dnnl::algorithm::eltwise_tanh}};
-
-  const auto& activation_type = activation_map.find(activation_name);
-
-  PADDLE_ENFORCE_NE(activation_type,
-                    activation_map.end(),
-                    platform::errors::InvalidArgument(
-                        "Activation '%s' not found in oneDNN algorithms mapper",
-                        activation_name));
-  return activation_type->second;
+static void AppendActivation(const framework::ExecutionContext& ctx,
+                             dnnl::post_ops& post_ops,
+                             float activation_scale = 1.0f) {
+  const auto invalid_attribute =
+      ctx.HasAttr("fuse_activation")
+          ? ctx.Attr<std::string>("fuse_activation").empty()
+          : true;
+  if (invalid_attribute) return;
+
+  const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
+  const auto fuse_alpha =
+      ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
+  const auto fuse_beta =
ctx.Attr("fuse_beta") : 0.0f; + + if (fuse_activation == "hard_sigmoid") { + post_ops.append_eltwise(activation_scale, + dnnl::algorithm::eltwise_linear, + fuse_alpha, + fuse_beta); + post_ops.append_eltwise( + activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f); + } else { + const std::unordered_map activation_map = { + {"abs", dnnl::algorithm::eltwise_abs}, + {"clip", dnnl::algorithm::eltwise_clip}, + {"gelu", dnnl::algorithm::eltwise_gelu_erf}, + {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf}, + {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh}, + {"hard_swish", dnnl::algorithm::eltwise_hardswish}, + {"leaky_relu", dnnl::algorithm::eltwise_relu}, + {"mish", dnnl::algorithm::eltwise_mish}, + {"relu", dnnl::algorithm::eltwise_relu}, + {"relu6", dnnl::algorithm::eltwise_bounded_relu}, + {"sigmoid", dnnl::algorithm::eltwise_logistic}, + {"sqrt", dnnl::algorithm::eltwise_sqrt}, + {"swish", dnnl::algorithm::eltwise_swish}, + {"tanh", dnnl::algorithm::eltwise_tanh}}; + + const auto& activation_type = activation_map.find(fuse_activation); + + PADDLE_ENFORCE_NE( + activation_type, + activation_map.end(), + platform::errors::InvalidArgument( + "Activation '%s' not found in oneDNN algorithms mapper", + fuse_activation)); + + post_ops.append_eltwise( + activation_scale, activation_type->second, fuse_alpha, fuse_beta); + } +} + +static std::unordered_map GetAttributeMap( + std::string act_type) { + std::unordered_map attr_map; + if (act_type == "swish") + attr_map.emplace("beta", "fuse_alpha"); + else if (act_type == "relu6") + attr_map.emplace("threshold", "fuse_alpha"); + else if (act_type == "hard_sigmoid") { + attr_map.emplace("slope", "fuse_alpha"); + attr_map.emplace("offset", "fuse_beta"); + } else if (act_type == "clip") { + attr_map.emplace("min", "fuse_alpha"); + attr_map.emplace("max", "fuse_beta"); + } else { + attr_map.emplace("alpha", "fuse_alpha"); + attr_map.emplace("beta", "fuse_beta"); + } + return attr_map; +} + +static std::vector GetSupportedActivations() { + return std::vector{"abs", + "clip", + "gelu", + "hard_sigmoid", + "hard_swish", + "leaky_relu", + "mish", + "relu", + "relu6", + "sigmoid", + "sqrt", + "swish", + "tanh"}; } class ReorderMKLDNNHandler { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py index 0c25a790138cd50a0954d209b69f297c941543f1..5e5dd4c719d98b94283e0be677afdecd771fcf5c 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py @@ -23,8 +23,8 @@ from paddle.fluid.core import PassVersionChecker class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest): - fuse_activation_alpha = None - fuse_activation_beta = None + fuse_alpha = None + fuse_beta = None pass_name = 'softplus_activation_mkldnn_fuse_pass' def setUp(self): @@ -34,13 +34,13 @@ class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest): shape=[-1, 3, 100, 100], dtype="float32") softplus_out = fluid.layers.softplus(data) - if self.fuse_activation_beta is not None: - activation_out = self.fuse_activation( - softplus_out, self.fuse_activation_alpha, - self.fuse_activation_beta) - elif self.fuse_activation_alpha is not None: - activation_out = self.fuse_activation( - softplus_out, self.fuse_activation_alpha) + if self.fuse_beta is not None: + activation_out = 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py
index 0c25a790138cd50a0954d209b69f297c941543f1..5e5dd4c719d98b94283e0be677afdecd771fcf5c 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py
@@ -23,8 +23,8 @@ from paddle.fluid.core import PassVersionChecker
 
 
 class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
-    fuse_activation_alpha = None
-    fuse_activation_beta = None
+    fuse_alpha = None
+    fuse_beta = None
     pass_name = 'softplus_activation_mkldnn_fuse_pass'
 
     def setUp(self):
@@ -34,13 +34,13 @@ class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
                 shape=[-1, 3, 100, 100],
                 dtype="float32")
             softplus_out = fluid.layers.softplus(data)
-            if self.fuse_activation_beta is not None:
-                activation_out = self.fuse_activation(
-                    softplus_out, self.fuse_activation_alpha,
-                    self.fuse_activation_beta)
-            elif self.fuse_activation_alpha is not None:
-                activation_out = self.fuse_activation(
-                    softplus_out, self.fuse_activation_alpha)
+            if self.fuse_beta is not None:
+                activation_out = self.fuse_activation(softplus_out,
+                                                      self.fuse_alpha,
+                                                      self.fuse_beta)
+            elif self.fuse_alpha is not None:
+                activation_out = self.fuse_activation(softplus_out,
+                                                      self.fuse_alpha)
             else:
                 activation_out = self.fuse_activation(softplus_out)
 
@@ -73,7 +73,7 @@ class SoftplusActivationLeakyReluOneDNNFusePassTest(
 
     def set_params(self):
         self.fuse_activation = fluid.layers.leaky_relu
-        self.fuse_activation_alpha = 0.3
+        self.fuse_alpha = 0.3
 
 
 class SoftplusActivationSwishOneDNNFusePassTest(
@@ -81,7 +81,7 @@ class SoftplusActivationSwishOneDNNFusePassTest(
 
     def set_params(self):
         self.fuse_activation = fluid.layers.swish
-        self.fuse_activation_alpha = 3
+        self.fuse_alpha = 3
 
 
 class SoftplusActivationHardSwishOneDNNFusePassTest(
@@ -110,8 +110,8 @@ class SoftplusActivationClipOneDNNFusePassTest(
 
     def set_params(self):
         self.fuse_activation = fluid.layers.clip
-        self.fuse_activation_alpha = 1.1
-        self.fuse_activation_beta = 5.2
+        self.fuse_alpha = 1.1
+        self.fuse_beta = 5.2
 
 
 class SoftplusActivationGeluErfOneDNNFusePassTest(
@@ -126,7 +126,7 @@ class SoftplusActivationGeluTanhOneDNNFusePassTest(
 
     def set_params(self):
         self.fuse_activation = fluid.layers.gelu
-        self.fuse_activation_alpha = True  # simulated "Approximate" attr
+        self.fuse_alpha = True  # simulated "Approximate" attr
 
 
 class SoftplusActivationRelu6OneDNNFusePassTest(