未验证 提交 826e2781 编写于 作者: S Sławomir Siwek 提交者: GitHub

Unify and generalize activation fuse passes (#44185)

* reduce redundancy

* python code style

* fix int8 ut
上级 526be01a
...@@ -931,65 +931,22 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input, ...@@ -931,65 +931,22 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
return bn_out_var; return bn_out_var;
} }
PDNode *patterns::ConvActivation::operator()( PDNode *patterns::OperatorActivation::operator()(
paddle::framework::ir::PDNode *conv_input, const std::string &operator_type, const std::string &activation_type) {
std::string conv_type, auto *preceding_op =
std::string activation_type) { pattern->NewNode(preceding_op_repr())->assert_is_op(operator_type);
// Create Operators auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
conv_input->assert_is_op_input(conv_type, "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
auto *activation_op =
pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input(conv_type, "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate() ->AsIntermediate()
->assert_is_only_output_of_op(conv_type) ->assert_is_only_output_of_op(operator_type)
->assert_is_op_input(activation_type); ->assert_is_op_input(activation_type);
// output
auto *activation_out_var = pattern->NewNode(activation_out_repr())
->AsOutput()
->assert_is_op_output(activation_type);
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
activation_op->LinksFrom({conv_out_var}).LinksTo({activation_out_var});
return activation_out_var;
}
PDNode *patterns::ElementwiseActivation::operator()(
paddle::framework::ir::PDNode *elementwise_a,
const std::string &elementwise_type,
const std::string &activation_type) {
// Create Operators
elementwise_a->assert_is_op_input(elementwise_type, "X");
auto *elementwise_op =
pattern->NewNode(elementwise_repr())->assert_is_op(elementwise_type);
auto *activation_op = auto *activation_op =
pattern->NewNode(activation_repr())->assert_is_op(activation_type); pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// Create variables auto *activation_out = pattern->NewNode(activation_out_repr())
auto *elementwise_b = pattern->NewNode(elementwise_b_repr())
->AsInput()
->assert_is_op_input(elementwise_type, "Y");
// intermediate variable, will be removed in the IR after fuse.
auto *elementwise_out_var =
pattern->NewNode(elementwise_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op(elementwise_type)
->assert_is_op_input(activation_type);
// output
auto *activation_out_var = pattern->NewNode(activation_out_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output(activation_type); ->assert_is_op_output(activation_type);
preceding_op->LinksTo({preceding_op_out});
elementwise_op->LinksFrom({elementwise_a, elementwise_b}) activation_op->LinksFrom({preceding_op_out}).LinksTo({activation_out});
.LinksTo({elementwise_out_var}); return activation_out;
activation_op->LinksFrom({elementwise_out_var}).LinksTo({activation_out_var});
return activation_out_var;
} }
PDNode *patterns::SeqConvEltAddRelu::operator()( PDNode *patterns::SeqConvEltAddRelu::operator()(
...@@ -1121,44 +1078,6 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x, ...@@ -1121,44 +1078,6 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
return fc_out_var; return fc_out_var;
} }
PDNode *patterns::FCActOneDNN::operator()(const std::string &act_type) {
auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc");
auto *fc_out = pattern->NewNode(fc_out_repr())
->assert_is_op_output("fc", "Out")
->assert_is_op_input(act_type);
auto *act =
pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
auto *act_out = pattern->NewNode(act_out_repr())
->assert_is_op_output(act_type, "Out")
->AsOutput();
fc->LinksTo({fc_out});
act->LinksFrom({fc_out}).LinksTo({act_out});
return act_out;
}
PDNode *patterns::SoftplusActivation::operator()(std::string activation_type) {
// Create Operators
auto *softplus_op =
pattern->NewNode(softplus_repr())->assert_is_op("softplus");
auto *activation_op =
pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// intermediate variable, will be removed in the IR after fuse.
auto *softplus_out = pattern->NewNode(softplus_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("softplus")
->assert_is_op_input(activation_type);
// output
auto *activation_out = pattern->NewNode(activation_out_repr())
->AsOutput()
->assert_is_op_output(activation_type);
softplus_op->LinksTo({softplus_out});
activation_op->LinksFrom({softplus_out}).LinksTo({activation_out});
return activation_out;
}
PDNode *patterns::Embedding::operator()(PDNode *x) { PDNode *patterns::Embedding::operator()(PDNode *x) {
x->assert_is_op_input("lookup_table", "Ids"); x->assert_is_op_input("lookup_table", "Ids");
auto *lookup_table_op = auto *lookup_table_op =
......
...@@ -524,49 +524,16 @@ struct ConvBN : public PatternBase { ...@@ -524,49 +524,16 @@ struct ConvBN : public PatternBase {
PATTERN_DECL_NODE(bn_saved_variance); PATTERN_DECL_NODE(bn_saved_variance);
}; };
// Conv with Activation struct OperatorActivation : public PatternBase {
// op: conv + activation OperatorActivation(PDPattern* pattern, const std::string& name_scope)
// named nodes: : PatternBase(pattern, name_scope, "operator_activation") {}
// conv_input, conv_weight,
// conv_out, conv,
// activation_out, activation
struct ConvActivation : public PatternBase {
ConvActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_activation") {}
PDNode* operator()(PDNode* conv_input,
std::string conv_type = "conv2d",
std::string activation_type = "relu");
// declare operator node's name PDNode* operator()(const std::string& operator_type,
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(activation);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(activation_out);
};
// Elementwise with Activation
// op: elementwise + activation
// named nodes:
// elementwise_a, elementwise_b,
// elementwise_out, elementwise,
// activation_out, activation
struct ElementwiseActivation : public PatternBase {
ElementwiseActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add_activation") {}
PDNode* operator()(PDNode* elementwise_a,
const std::string& elementwise_type,
const std::string& activation_type); const std::string& activation_type);
// declare operator node's name PATTERN_DECL_NODE(preceding_op);
PATTERN_DECL_NODE(elementwise); PATTERN_DECL_NODE(preceding_op_out);
PATTERN_DECL_NODE(activation); PATTERN_DECL_NODE(activation);
// declare variable node's name
PATTERN_DECL_NODE(elementwise_b);
PATTERN_DECL_NODE(elementwise_out);
PATTERN_DECL_NODE(activation_out); PATTERN_DECL_NODE(activation_out);
}; };
...@@ -639,45 +606,6 @@ struct FCMKLDNN : public PatternBase { ...@@ -639,45 +606,6 @@ struct FCMKLDNN : public PatternBase {
PATTERN_DECL_NODE(output); PATTERN_DECL_NODE(output);
}; };
//
// \brief Pattern looking for fc and a directly following activation
// operator.
//
// \note Currently only gelu and tanh are supported as an activation
// function.
// Formula: act(fc(x))
// Op: fc + act
struct FCActOneDNN : public PatternBase {
FCActOneDNN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc_act_onednn") {}
PDNode* operator()(const std::string& act_type);
// declare operator node's name
PATTERN_DECL_NODE(fc);
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(fc_out);
PATTERN_DECL_NODE(act_out);
};
// Fuse softplus with activation
// ops: softplus + activation
// nodes:
// softplus, softplus_out,
// activation, activation_out
struct SoftplusActivation : public PatternBase {
SoftplusActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "softplus_activation") {}
PDNode* operator()(std::string activation_type);
// declare operator node's name
PATTERN_DECL_NODE(softplus);
PATTERN_DECL_NODE(activation);
PATTERN_DECL_NODE(softplus_out);
PATTERN_DECL_NODE(activation_out);
};
// Embedding // Embedding
struct Embedding : public PatternBase { struct Embedding : public PatternBase {
Embedding(PDPattern* pattern, const std::string& name_scope) Embedding(PDPattern* pattern, const std::string& name_scope)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
namespace paddle { namespace paddle {
...@@ -24,61 +25,27 @@ namespace ir { ...@@ -24,61 +25,27 @@ namespace ir {
using string::PrettyLogDetail; using string::PrettyLogDetail;
void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const { void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
std::vector<std::string> act_types = {"relu", auto act_types = paddle::platform::GetSupportedActivations();
"mish",
"swish",
"sqrt",
"hard_swish",
"sigmoid",
"abs",
"gelu",
"relu6",
"clip",
"tanh",
"hard_sigmoid",
"leaky_relu"};
std::vector<std::string> conv_types = {"conv2d"}; std::vector<std::string> conv_types = {"conv2d"};
for (const auto& conv_type : conv_types) for (const auto& conv_type : conv_types)
for (auto& act_type : act_types) { for (auto& act_type : act_types) {
std::unordered_map<std::string, std::string> attrs_map; FuseConvAct(graph, conv_type, act_type);
if (act_type == "swish")
attrs_map.emplace("beta", "fuse_alpha");
else if (act_type == "relu6")
attrs_map.emplace("threshold", "fuse_alpha");
else if (act_type == "hard_sigmoid") {
attrs_map.emplace("slope", "fuse_alpha");
attrs_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") {
attrs_map.emplace("min", "fuse_alpha");
attrs_map.emplace("max", "fuse_beta");
} else {
attrs_map.emplace("alpha", "fuse_alpha");
attrs_map.emplace("beta", "fuse_beta");
}
FuseConvAct(graph, conv_type, act_type, attrs_map);
} }
} }
void ConvActivationMkldnnFusePass::FuseConvAct( void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
Graph* graph,
const std::string& conv_type, const std::string& conv_type,
std::string& act_type, std::string& act_type) const {
const std::unordered_map<std::string, std::string>& attrs_map) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph); FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern() patterns::OperatorActivation conv_act_pattern(gpd.mutable_pattern(),
->NewNode("conv_activation_mkldnn_fuse/conv_input")
->AsInput()
->assert_is_op_input(conv_type, "Input");
patterns::ConvActivation conv_act_pattern(gpd.mutable_pattern(),
"conv_activation_mkldnn_fuse"); "conv_activation_mkldnn_fuse");
conv_act_pattern(conv_input, conv_type, act_type); conv_act_pattern(conv_type, act_type);
int found_conv_activation_count = 0; int found_conv_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
...@@ -90,16 +57,16 @@ void ConvActivationMkldnnFusePass::FuseConvAct( ...@@ -90,16 +57,16 @@ void ConvActivationMkldnnFusePass::FuseConvAct(
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, conv_act_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv, preceding_op, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_act_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_out, preceding_op_out, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation, activation, conv_act_pattern); GET_IR_NODE_FROM_SUBGRAPH(activation, activation, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
OpDesc* conv_op = conv->Op(); OpDesc* conv_op = conv->Op();
OpDesc* act_op = activation->Op(); OpDesc* act_op = activation->Op();
for (const auto& attrs : attrs_map) { auto attr_map = paddle::platform::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) { if (act_op->HasAttr(attrs.first)) {
conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first)); conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
} }
......
...@@ -31,11 +31,9 @@ class ConvActivationMkldnnFusePass : public FusePassBase { ...@@ -31,11 +31,9 @@ class ConvActivationMkldnnFusePass : public FusePassBase {
protected: protected:
void ApplyImpl(Graph *graph) const override; void ApplyImpl(Graph *graph) const override;
void FuseConvAct( void FuseConvAct(Graph *graph,
Graph *graph,
const std::string &conv_type, const std::string &conv_type,
std::string &act_type, std::string &act_type) const;
const std::unordered_map<std::string, std::string> &attrs_map) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
namespace paddle { namespace paddle {
...@@ -26,71 +27,40 @@ namespace ir { ...@@ -26,71 +27,40 @@ namespace ir {
using string::PrettyLogDetail; using string::PrettyLogDetail;
void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const { void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {"relu", auto act_types = paddle::platform::GetSupportedActivations();
"tanh",
"leaky_relu",
"swish",
"hard_swish",
"sqrt",
"abs",
"clip",
"gelu",
"relu6",
"sigmoid"};
std::vector<std::string> elt_types = { std::vector<std::string> elt_types = {
"elementwise_add", "elementwise_sub", "elementwise_mul"}; "elementwise_add", "elementwise_sub", "elementwise_mul"};
for (const auto &elt_type : elt_types) for (const auto &elt_type : elt_types)
for (const auto &act_type : act_types) { for (const auto &act_type : act_types) {
std::unordered_map<std::string, std::string> attr_map; FuseElementwiseAct(graph, elt_type, act_type);
if (act_type == "swish")
attr_map.emplace("beta", "activation_alpha");
else if (act_type == "relu6")
attr_map.emplace("threshold", "activation_alpha");
else if (act_type == "clip") {
attr_map.emplace("min", "activation_alpha");
attr_map.emplace("max", "activation_beta");
} else {
attr_map.emplace("alpha", "activation_alpha");
attr_map.emplace("beta", "activation_beta");
}
FuseElementwiseAct(graph, elt_type, act_type, attr_map);
} }
} }
void ElementwiseActivationOneDNNPass::FuseElementwiseAct( void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
Graph *graph, Graph *graph,
const std::string &elt_type, const std::string &elt_type,
const std::string &act_type, const std::string &act_type) const {
const std::unordered_map<std::string, std::string> &attr_map) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph); FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto *elementwise_input = gpd.mutable_pattern() patterns::OperatorActivation elementwise_act_pattern(gpd.mutable_pattern(),
->NewNode(elt_type + "_act/elementwise_input")
->AsInput()
->assert_is_op_input(elt_type, "X");
patterns::ElementwiseActivation elementwise_act_pattern(gpd.mutable_pattern(),
elt_type + "_act"); elt_type + "_act");
elementwise_act_pattern(elementwise_input, elt_type, act_type); elementwise_act_pattern(elt_type, act_type);
int found_elementwise_activation_count = 0; int found_elementwise_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) { Graph *g) {
VLOG(4) << "Fuse " << elt_type << " with activation op."; VLOG(4) << "Fuse " << elt_type << " with activation op.";
// Elementwise output
GET_IR_NODE_FROM_SUBGRAPH(
elementwise_out, elementwise_out, elementwise_act_pattern);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, elementwise_act_pattern); elementwise, preceding_op, elementwise_act_pattern);
// ops
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
elementwise, elementwise, elementwise_act_pattern); elementwise_out, preceding_op_out, elementwise_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation, activation, elementwise_act_pattern); GET_IR_NODE_FROM_SUBGRAPH(activation, activation, elementwise_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, elementwise_act_pattern);
auto *elementwise_op = elementwise->Op(); auto *elementwise_op = elementwise->Op();
...@@ -106,6 +76,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct( ...@@ -106,6 +76,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
} }
auto *activation_op = activation->Op(); auto *activation_op = activation->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
for (const auto &attr : attr_map) { for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) { if (activation_op->HasAttr(attr.first)) {
elementwise_op->SetAttr(attr.second, elementwise_op->SetAttr(attr.second,
...@@ -115,9 +86,9 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct( ...@@ -115,9 +86,9 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
if (act_type == "gelu" && activation_op->HasAttr("approximate") && if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
BOOST_GET_CONST(bool, activation_op->GetAttr("approximate"))) BOOST_GET_CONST(bool, activation_op->GetAttr("approximate")))
elementwise_op->SetAttr("activation_type", std::string("gelu_tanh")); elementwise_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
else else
elementwise_op->SetAttr("activation_type", act_type); elementwise_op->SetAttr("fuse_activation", act_type);
elementwise_op->SetOutput("Out", {activation_out->Name()}); elementwise_op->SetOutput("Out", {activation_out->Name()});
...@@ -146,14 +117,16 @@ REGISTER_PASS_CAPABILITY(elt_act_mkldnn_fuse_pass) ...@@ -146,14 +117,16 @@ REGISTER_PASS_CAPABILITY(elt_act_mkldnn_fuse_pass)
.LE("elementwise_add", 1) .LE("elementwise_add", 1)
.LE("elementwise_sub", 1) .LE("elementwise_sub", 1)
.LE("elementwise_mul", 1) .LE("elementwise_mul", 1)
.LE("relu", 0) .EQ("abs", 0)
.LE("tanh", 0)
.LE("leaky_relu", 1)
.LE("swish", 0)
.LE("hard_swish", 0)
.LE("sqrt", 0)
.LE("abs", 0)
.LE("clip", 1) .LE("clip", 1)
.LE("gelu", 0) .EQ("gelu", 0)
.LE("relu6", 0) .EQ("hard_sigmoid", 0)
.LE("sigmoid", 0)); .LE("hard_swish", 0)
.LE("leaky_relu", 1)
.LE("mish", 1)
.EQ("relu", 0)
.EQ("relu6", 0)
.EQ("sigmoid", 0)
.EQ("sqrt", 0)
.EQ("swish", 0)
.EQ("tanh", 0));
...@@ -34,11 +34,9 @@ class ElementwiseActivationOneDNNPass : public FusePassBase { ...@@ -34,11 +34,9 @@ class ElementwiseActivationOneDNNPass : public FusePassBase {
protected: protected:
void ApplyImpl(Graph *graph) const override; void ApplyImpl(Graph *graph) const override;
void FuseElementwiseAct( void FuseElementwiseAct(Graph *graph,
Graph *graph,
const std::string &elt_types, const std::string &elt_types,
const std::string &act_types, const std::string &act_types) const;
const std::unordered_map<std::string, std::string> &attr_map) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -39,20 +39,17 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph, ...@@ -39,20 +39,17 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
FusePassBase::Init("fc_act", graph); FusePassBase::Init("fc_act", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::FCActOneDNN fc_act_pattern(gpd.mutable_pattern(), "fc_act"); patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
fc_act_pattern(act_type); fc_act_pattern("fc", act_type);
int found_fc_act_count = 0; int found_fc_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) { Graph *g) {
VLOG(4) << "Fuse fc with activation op."; VLOG(4) << "Fuse fc with activation op.";
// FC output GET_IR_NODE_FROM_SUBGRAPH(fc, preceding_op, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_act_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_out, preceding_op_out, fc_act_pattern);
// ACT output GET_IR_NODE_FROM_SUBGRAPH(act, activation, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fc_act_pattern); GET_IR_NODE_FROM_SUBGRAPH(act_out, activation_out, fc_act_pattern);
// ops
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, fc_act_pattern);
auto *fc_op = fc->Op(); auto *fc_op = fc->Op();
auto *act_op = act->Op(); auto *act_op = act->Op();
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
namespace paddle { namespace paddle {
...@@ -26,59 +27,34 @@ namespace ir { ...@@ -26,59 +27,34 @@ namespace ir {
using string::PrettyLogDetail; using string::PrettyLogDetail;
void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const { void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {"relu", auto act_types = paddle::platform::GetSupportedActivations();
"tanh",
"leaky_relu",
"swish",
"hardswish",
"sqrt",
"abs",
"clip",
"gelu",
"relu6",
"sigmoid"};
for (const auto &act_type : act_types) { for (const auto &act_type : act_types) {
std::unordered_map<std::string, std::string> attr_map; FuseSoftplusActivation(graph, act_type);
if (act_type == "swish")
attr_map.emplace("beta", "fuse_activation_alpha");
else if (act_type == "relu6")
attr_map.emplace("threshold", "fuse_activation_alpha");
else if (act_type == "clip") {
attr_map.emplace("min", "fuse_activation_alpha");
attr_map.emplace("max", "fuse_activation_beta");
} else {
attr_map.emplace("alpha", "fuse_activation_alpha");
attr_map.emplace("beta", "fuse_activation_beta");
}
FuseSoftplusActivation(graph, act_type, attr_map);
} }
} }
void SoftplusActivationOneDNNPass::FuseSoftplusActivation( void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
Graph *graph, Graph *graph, const std::string &act_type) const {
const std::string &fuse_activation_type,
const std::unordered_map<std::string, std::string> &attr_map) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("softplus_activation", graph); FusePassBase::Init("softplus_activation", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::SoftplusActivation softplus_activation_pattern( patterns::OperatorActivation softplus_activation_pattern(
gpd.mutable_pattern(), "softplus_activation"); gpd.mutable_pattern(), "softplus_activation");
softplus_activation_pattern(fuse_activation_type); softplus_activation_pattern("softplus", act_type);
int found_softplus_activation_count = 0; int found_softplus_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) { Graph *g) {
VLOG(4) << "Fuse softplus with activation op."; VLOG(4) << "Fuse softplus with activation op.";
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
softplus_out, softplus_out, softplus_activation_pattern); softplus_out, preceding_op_out, softplus_activation_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, softplus_activation_pattern); activation_out, activation_out, softplus_activation_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
GET_IR_NODE_FROM_SUBGRAPH(softplus, softplus, softplus_activation_pattern); softplus, preceding_op, softplus_activation_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
activation, activation, softplus_activation_pattern); activation, activation, softplus_activation_pattern);
...@@ -94,18 +70,18 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation( ...@@ -94,18 +70,18 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
} }
auto *activation_op = activation->Op(); auto *activation_op = activation->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
for (const auto &attr : attr_map) { for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) { if (activation_op->HasAttr(attr.first)) {
softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first)); softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first));
} }
} }
if (fuse_activation_type == "gelu" && if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
activation_op->HasAttr("approximate") &&
BOOST_GET_CONST(bool, activation_op->GetAttr("approximate"))) BOOST_GET_CONST(bool, activation_op->GetAttr("approximate")))
softplus_op->SetAttr("fuse_activation_type", std::string("gelu_tanh")); softplus_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
else else
softplus_op->SetAttr("fuse_activation_type", fuse_activation_type); softplus_op->SetAttr("fuse_activation", act_type);
softplus_op->SetAttr("use_mkldnn", true); softplus_op->SetAttr("use_mkldnn", true);
...@@ -121,7 +97,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation( ...@@ -121,7 +97,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
if (!Has("disable_logs") || !Get<bool>("disable_logs")) if (!Has("disable_logs") || !Get<bool>("disable_logs"))
PrettyLogDetail("--- fused %d softplus with %s activation", PrettyLogDetail("--- fused %d softplus with %s activation",
found_softplus_activation_count, found_softplus_activation_count,
fuse_activation_type); act_type);
} }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
...@@ -133,13 +109,16 @@ REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass) ...@@ -133,13 +109,16 @@ REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.LE("softplus", 1) .LE("softplus", 1)
.EQ("relu", 0)
.EQ("tanh", 0)
.LE("leaky_relu", 1)
.EQ("swish", 0)
.EQ("hard_swish", 0)
.EQ("sqrt", 0)
.EQ("abs", 0) .EQ("abs", 0)
.LE("relu6", 1)
.LE("clip", 1) .LE("clip", 1)
.EQ("gelu", 0)); .EQ("gelu", 0)
.EQ("hard_sigmoid", 0)
.LE("hard_swish", 0)
.LE("leaky_relu", 1)
.LE("mish", 1)
.EQ("relu", 0)
.EQ("relu6", 0)
.EQ("sigmoid", 0)
.EQ("sqrt", 0)
.EQ("swish", 0)
.EQ("tanh", 0));
...@@ -34,10 +34,8 @@ class SoftplusActivationOneDNNPass : public FusePassBase { ...@@ -34,10 +34,8 @@ class SoftplusActivationOneDNNPass : public FusePassBase {
protected: protected:
void ApplyImpl(ir::Graph *graph) const override; void ApplyImpl(ir::Graph *graph) const override;
void FuseSoftplusActivation( void FuseSoftplusActivation(ir::Graph *graph,
ir::Graph *graph, const std::string &act_type) const;
const std::string &fuse_activation_type,
const std::unordered_map<std::string, std::string> &attr_map) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -50,22 +50,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> { ...@@ -50,22 +50,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
private: private:
dnnl::post_ops get_post_ops(const framework::ExecutionContext& ctx) const { dnnl::post_ops get_post_ops(const framework::ExecutionContext& ctx) const {
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
if (ctx.HasAttr("activation_type")) { platform::AppendActivation(ctx, post_operations);
const float scale = ctx.HasAttr("activation_scale")
? ctx.Attr<float>("activation_scale")
: 1.0f;
const float alpha = ctx.HasAttr("activation_alpha")
? ctx.Attr<float>("activation_alpha")
: 0.0f;
const float beta = ctx.HasAttr("activation_beta")
? ctx.Attr<float>("activation_beta")
: 0.0f;
const auto activation_algorithm = platform::AcquireActivationAlgorithm(
ctx.Attr<std::string>("activation_type"));
post_operations.append_eltwise(scale, activation_algorithm, alpha, beta);
}
return post_operations; return post_operations;
} }
......
...@@ -553,10 +553,6 @@ class ConvMKLDNNHandlerT ...@@ -553,10 +553,6 @@ class ConvMKLDNNHandlerT
dnnl::primitive_attr conv_attr; dnnl::primitive_attr conv_attr;
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection"); const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
float sum_scale = 1.0f; float sum_scale = 1.0f;
...@@ -587,19 +583,7 @@ class ConvMKLDNNHandlerT ...@@ -587,19 +583,7 @@ class ConvMKLDNNHandlerT
post_operations.append_sum(sum_scale); post_operations.append_sum(sum_scale);
} }
if (fuse_activation == "hard_sigmoid") { platform::AppendActivation(ctx, post_operations, activation_scale);
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
fuse_alpha,
fuse_beta);
post_operations.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
} else if (fuse_activation != "") {
const auto activation_algorithm =
platform::AcquireActivationAlgorithm(fuse_activation);
post_operations.append_eltwise(
activation_scale, activation_algorithm, fuse_alpha, fuse_beta);
}
conv_attr.set_post_ops(post_operations); conv_attr.set_post_ops(post_operations);
return conv_attr; return conv_attr;
......
...@@ -46,7 +46,7 @@ class SoftplusMKLDNNHandler ...@@ -46,7 +46,7 @@ class SoftplusMKLDNNHandler
1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f); 1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f);
} }
AppendFusedActivationIfExists(ctx, &post_ops); platform::AppendActivation(ctx, post_ops);
dnnl::primitive_attr attrs; dnnl::primitive_attr attrs;
attrs.set_post_ops(post_ops); attrs.set_post_ops(post_ops);
...@@ -62,42 +62,8 @@ class SoftplusMKLDNNHandler ...@@ -62,42 +62,8 @@ class SoftplusMKLDNNHandler
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->src1_desc(), platform::to_void_cast<float>(beta)); this->fwd_pd_->src1_desc(), platform::to_void_cast<float>(beta));
} }
private:
// Appends one eltwise post-op for the activation named by the
// "fuse_activation_type" attribute. Names that are not present in algo_map
// (including an empty string) leave post_ops untouched.
//
// ctx      - execution context the fuse attributes are read from.
// post_ops - post-op chain that receives the eltwise entry; must not be null.
void AppendFusedActivationIfExists(const framework::ExecutionContext& ctx,
                                   dnnl::post_ops* post_ops) {
  const auto algo_it =
      algo_map.find(ctx.Attr<std::string>("fuse_activation_type"));
  if (algo_it == algo_map.end()) {
    return;  // nothing to fuse
  }

  // for future int8 support
  const float scale_out = ctx.Attr<float>("fuse_activation_scale");
  const float alpha = ctx.Attr<float>("fuse_activation_alpha");
  const float beta = ctx.Attr<float>("fuse_activation_beta");
  post_ops->append_eltwise(scale_out, algo_it->second, alpha, beta);
}
static const std::unordered_map<std::string, dnnl::algorithm> algo_map;
}; };
// Definition of the static lookup table SoftplusMKLDNNHandler<T>::algo_map.
// Keys are activation names, values are the dnnl::algorithm enumerators used
// for the fused eltwise post-op. AppendFusedActivationIfExists searches this
// table with the "fuse_activation_type" attribute and appends nothing when the
// name is not listed here.
template <typename T>
const std::unordered_map<std::string, dnnl::algorithm>
SoftplusMKLDNNHandler<T>::algo_map = {
{"relu", dnnl::algorithm::eltwise_relu},
{"tanh", dnnl::algorithm::eltwise_tanh},
// Same algorithm as "relu". NOTE(review): presumably the negative slope is
// supplied through the alpha attribute - confirm against the fuse pass.
{"leaky_relu", dnnl::algorithm::eltwise_relu},
{"swish", dnnl::algorithm::eltwise_swish},
{"hardswish", dnnl::algorithm::eltwise_hardswish},
{"sqrt", dnnl::algorithm::eltwise_sqrt},
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
// "gelu" selects the erf-based variant; the tanh approximation has its own key.
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
{"sigmoid", dnnl::algorithm::eltwise_logistic}};
template <typename T> template <typename T>
void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) { void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) {
const auto& dev_ctx = const auto& dev_ctx =
......
...@@ -1013,9 +1013,30 @@ class ActivationMKLDNNHandler ...@@ -1013,9 +1013,30 @@ class ActivationMKLDNNHandler
} }
}; };
static const dnnl::algorithm AcquireActivationAlgorithm( static void AppendActivation(const framework::ExecutionContext& ctx,
std::string activation_name) { dnnl::post_ops& post_ops,
std::unordered_map<std::string, dnnl::algorithm> activation_map = { float activation_scale = 1.0f) {
const auto invalid_attribute =
ctx.HasAttr("fuse_activation")
? ctx.Attr<std::string>("fuse_activation").empty()
: true;
if (invalid_attribute) return;
const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
const auto fuse_alpha =
ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
const auto fuse_beta =
ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;
if (fuse_activation == "hard_sigmoid") {
post_ops.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
fuse_alpha,
fuse_beta);
post_ops.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
} else {
const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
{"abs", dnnl::algorithm::eltwise_abs}, {"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip}, {"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf}, {"gelu", dnnl::algorithm::eltwise_gelu_erf},
...@@ -1031,14 +1052,54 @@ static const dnnl::algorithm AcquireActivationAlgorithm( ...@@ -1031,14 +1052,54 @@ static const dnnl::algorithm AcquireActivationAlgorithm(
{"swish", dnnl::algorithm::eltwise_swish}, {"swish", dnnl::algorithm::eltwise_swish},
{"tanh", dnnl::algorithm::eltwise_tanh}}; {"tanh", dnnl::algorithm::eltwise_tanh}};
const auto& activation_type = activation_map.find(activation_name); const auto& activation_type = activation_map.find(fuse_activation);
PADDLE_ENFORCE_NE(activation_type, PADDLE_ENFORCE_NE(
activation_type,
activation_map.end(), activation_map.end(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Activation '%s' not found in oneDNN algorithms mapper", "Activation '%s' not found in oneDNN algorithms mapper",
activation_name)); fuse_activation));
return activation_type->second;
post_ops.append_eltwise(
activation_scale, activation_type->second, fuse_alpha, fuse_beta);
}
}
// Builds the attribute-name translation table for one activation type.
//
// Each entry maps a source attribute name (e.g. "slope") to the fused
// attribute name it is stored under ("fuse_alpha" or "fuse_beta").
// NOTE(review): the keys look like the standalone activation operator's own
// attribute names - confirm against the fuse passes that consume this map.
//
// act_type - activation name; "swish", "relu6", "hard_sigmoid" and "clip" have
//            dedicated tables, every other value gets the generic
//            alpha/beta table.
static std::unordered_map<std::string, std::string> GetAttributeMap(
    std::string act_type) {
  using AttrMap = std::unordered_map<std::string, std::string>;

  if (act_type == "swish") {
    return AttrMap{{"beta", "fuse_alpha"}};
  }
  if (act_type == "relu6") {
    return AttrMap{{"threshold", "fuse_alpha"}};
  }
  if (act_type == "hard_sigmoid") {
    return AttrMap{{"slope", "fuse_alpha"}, {"offset", "fuse_beta"}};
  }
  if (act_type == "clip") {
    return AttrMap{{"min", "fuse_alpha"}, {"max", "fuse_beta"}};
  }
  // Generic fallback shared by all remaining activation types.
  return AttrMap{{"alpha", "fuse_alpha"}, {"beta", "fuse_beta"}};
}
// Returns the activation names reported as supported, in alphabetical order.
// A fresh vector is returned on every call, so callers may modify it freely.
static std::vector<std::string> GetSupportedActivations() {
  static const char* const kActivationNames[] = {
      "abs",        "clip", "gelu", "hard_sigmoid", "hard_swish",
      "leaky_relu", "mish", "relu", "relu6",        "sigmoid",
      "sqrt",       "swish", "tanh"};
  return std::vector<std::string>(std::begin(kActivationNames),
                                  std::end(kActivationNames));
}
class ReorderMKLDNNHandler { class ReorderMKLDNNHandler {
......
...@@ -23,8 +23,8 @@ from paddle.fluid.core import PassVersionChecker ...@@ -23,8 +23,8 @@ from paddle.fluid.core import PassVersionChecker
class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest): class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
fuse_activation_alpha = None fuse_alpha = None
fuse_activation_beta = None fuse_beta = None
pass_name = 'softplus_activation_mkldnn_fuse_pass' pass_name = 'softplus_activation_mkldnn_fuse_pass'
def setUp(self): def setUp(self):
...@@ -34,13 +34,13 @@ class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest): ...@@ -34,13 +34,13 @@ class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
shape=[-1, 3, 100, 100], shape=[-1, 3, 100, 100],
dtype="float32") dtype="float32")
softplus_out = fluid.layers.softplus(data) softplus_out = fluid.layers.softplus(data)
if self.fuse_activation_beta is not None: if self.fuse_beta is not None:
activation_out = self.fuse_activation( activation_out = self.fuse_activation(softplus_out,
softplus_out, self.fuse_activation_alpha, self.fuse_alpha,
self.fuse_activation_beta) self.fuse_beta)
elif self.fuse_activation_alpha is not None: elif self.fuse_alpha is not None:
activation_out = self.fuse_activation( activation_out = self.fuse_activation(softplus_out,
softplus_out, self.fuse_activation_alpha) self.fuse_alpha)
else: else:
activation_out = self.fuse_activation(softplus_out) activation_out = self.fuse_activation(softplus_out)
...@@ -73,7 +73,7 @@ class SoftplusActivationLeakyReluOneDNNFusePassTest( ...@@ -73,7 +73,7 @@ class SoftplusActivationLeakyReluOneDNNFusePassTest(
def set_params(self): def set_params(self):
self.fuse_activation = fluid.layers.leaky_relu self.fuse_activation = fluid.layers.leaky_relu
self.fuse_activation_alpha = 0.3 self.fuse_alpha = 0.3
class SoftplusActivationSwishOneDNNFusePassTest( class SoftplusActivationSwishOneDNNFusePassTest(
...@@ -81,7 +81,7 @@ class SoftplusActivationSwishOneDNNFusePassTest( ...@@ -81,7 +81,7 @@ class SoftplusActivationSwishOneDNNFusePassTest(
def set_params(self): def set_params(self):
self.fuse_activation = fluid.layers.swish self.fuse_activation = fluid.layers.swish
self.fuse_activation_alpha = 3 self.fuse_alpha = 3
class SoftplusActivationHardSwishOneDNNFusePassTest( class SoftplusActivationHardSwishOneDNNFusePassTest(
...@@ -110,8 +110,8 @@ class SoftplusActivationClipOneDNNFusePassTest( ...@@ -110,8 +110,8 @@ class SoftplusActivationClipOneDNNFusePassTest(
def set_params(self): def set_params(self):
self.fuse_activation = fluid.layers.clip self.fuse_activation = fluid.layers.clip
self.fuse_activation_alpha = 1.1 self.fuse_alpha = 1.1
self.fuse_activation_beta = 5.2 self.fuse_beta = 5.2
class SoftplusActivationGeluErfOneDNNFusePassTest( class SoftplusActivationGeluErfOneDNNFusePassTest(
...@@ -126,7 +126,7 @@ class SoftplusActivationGeluTanhOneDNNFusePassTest( ...@@ -126,7 +126,7 @@ class SoftplusActivationGeluTanhOneDNNFusePassTest(
def set_params(self): def set_params(self):
self.fuse_activation = fluid.layers.gelu self.fuse_activation = fluid.layers.gelu
self.fuse_activation_alpha = True # simulated "Approximate" attr self.fuse_alpha = True # simulated "Approximate" attr
class SoftplusActivationRelu6OneDNNFusePassTest( class SoftplusActivationRelu6OneDNNFusePassTest(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册