未验证 提交 826e2781 编写于 作者: S Sławomir Siwek 提交者: GitHub

Unify and generalize activation fuse passes (#44185)

* reduce redundancy

* python code style

* fix int8 ut
上级 526be01a
......@@ -931,65 +931,22 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
return bn_out_var;
}
PDNode *patterns::ConvActivation::operator()(
paddle::framework::ir::PDNode *conv_input,
std::string conv_type,
std::string activation_type) {
// Create Operators
conv_input->assert_is_op_input(conv_type, "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
auto *activation_op =
pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input(conv_type, "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
PDNode *patterns::OperatorActivation::operator()(
const std::string &operator_type, const std::string &activation_type) {
auto *preceding_op =
pattern->NewNode(preceding_op_repr())->assert_is_op(operator_type);
auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op(conv_type)
->assert_is_only_output_of_op(operator_type)
->assert_is_op_input(activation_type);
// output
auto *activation_out_var = pattern->NewNode(activation_out_repr())
->AsOutput()
->assert_is_op_output(activation_type);
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
activation_op->LinksFrom({conv_out_var}).LinksTo({activation_out_var});
return activation_out_var;
}
PDNode *patterns::ElementwiseActivation::operator()(
paddle::framework::ir::PDNode *elementwise_a,
const std::string &elementwise_type,
const std::string &activation_type) {
// Create Operators
elementwise_a->assert_is_op_input(elementwise_type, "X");
auto *elementwise_op =
pattern->NewNode(elementwise_repr())->assert_is_op(elementwise_type);
auto *activation_op =
pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// Create variables
auto *elementwise_b = pattern->NewNode(elementwise_b_repr())
->AsInput()
->assert_is_op_input(elementwise_type, "Y");
// intermediate variable, will be removed in the IR after fuse.
auto *elementwise_out_var =
pattern->NewNode(elementwise_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op(elementwise_type)
->assert_is_op_input(activation_type);
// output
auto *activation_out_var = pattern->NewNode(activation_out_repr())
auto *activation_out = pattern->NewNode(activation_out_repr())
->AsOutput()
->assert_is_op_output(activation_type);
elementwise_op->LinksFrom({elementwise_a, elementwise_b})
.LinksTo({elementwise_out_var});
activation_op->LinksFrom({elementwise_out_var}).LinksTo({activation_out_var});
return activation_out_var;
preceding_op->LinksTo({preceding_op_out});
activation_op->LinksFrom({preceding_op_out}).LinksTo({activation_out});
return activation_out;
}
PDNode *patterns::SeqConvEltAddRelu::operator()(
......@@ -1121,44 +1078,6 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
return fc_out_var;
}
PDNode *patterns::FCActOneDNN::operator()(const std::string &act_type) {
auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc");
auto *fc_out = pattern->NewNode(fc_out_repr())
->assert_is_op_output("fc", "Out")
->assert_is_op_input(act_type);
auto *act =
pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
auto *act_out = pattern->NewNode(act_out_repr())
->assert_is_op_output(act_type, "Out")
->AsOutput();
fc->LinksTo({fc_out});
act->LinksFrom({fc_out}).LinksTo({act_out});
return act_out;
}
PDNode *patterns::SoftplusActivation::operator()(std::string activation_type) {
// Create Operators
auto *softplus_op =
pattern->NewNode(softplus_repr())->assert_is_op("softplus");
auto *activation_op =
pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// intermediate variable, will be removed in the IR after fuse.
auto *softplus_out = pattern->NewNode(softplus_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("softplus")
->assert_is_op_input(activation_type);
// output
auto *activation_out = pattern->NewNode(activation_out_repr())
->AsOutput()
->assert_is_op_output(activation_type);
softplus_op->LinksTo({softplus_out});
activation_op->LinksFrom({softplus_out}).LinksTo({activation_out});
return activation_out;
}
PDNode *patterns::Embedding::operator()(PDNode *x) {
x->assert_is_op_input("lookup_table", "Ids");
auto *lookup_table_op =
......
......@@ -524,49 +524,16 @@ struct ConvBN : public PatternBase {
PATTERN_DECL_NODE(bn_saved_variance);
};
// Conv with Activation
// op: conv + activation
// named nodes:
// conv_input, conv_weight,
// conv_out, conv,
// activation_out, activation
struct ConvActivation : public PatternBase {
ConvActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_activation") {}
PDNode* operator()(PDNode* conv_input,
std::string conv_type = "conv2d",
std::string activation_type = "relu");
struct OperatorActivation : public PatternBase {
OperatorActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "operator_activation") {}
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(activation);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(activation_out);
};
// Elementwise with Activation
// op: elementwise + activation
// named nodes:
// elementwise_a, elementwise_b,
// elementwise_out, elementwise,
// activation_out, activation
struct ElementwiseActivation : public PatternBase {
ElementwiseActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add_activation") {}
PDNode* operator()(PDNode* elementwise_a,
const std::string& elementwise_type,
PDNode* operator()(const std::string& operator_type,
const std::string& activation_type);
// declare operator node's name
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(preceding_op);
PATTERN_DECL_NODE(preceding_op_out);
PATTERN_DECL_NODE(activation);
// declare variable node's name
PATTERN_DECL_NODE(elementwise_b);
PATTERN_DECL_NODE(elementwise_out);
PATTERN_DECL_NODE(activation_out);
};
......@@ -639,45 +606,6 @@ struct FCMKLDNN : public PatternBase {
PATTERN_DECL_NODE(output);
};
//
// \brief Pattern looking for fc and a directly following activation
// operator.
//
// \note Currently only gelu and tanh are supported as an activation
// function.
// Formula: act(fc(x))
// Op: fc + act
struct FCActOneDNN : public PatternBase {
FCActOneDNN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc_act_onednn") {}
PDNode* operator()(const std::string& act_type);
// declare operator node's name
PATTERN_DECL_NODE(fc);
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(fc_out);
PATTERN_DECL_NODE(act_out);
};
// Fuse softplus with activation
// ops: softplus + activation
// nodes:
// softplus, softplus_out,
// activation, activation_out
struct SoftplusActivation : public PatternBase {
SoftplusActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "softplus_activation") {}
PDNode* operator()(std::string activation_type);
// declare operator node's name
PATTERN_DECL_NODE(softplus);
PATTERN_DECL_NODE(activation);
PATTERN_DECL_NODE(softplus_out);
PATTERN_DECL_NODE(activation_out);
};
// Embedding
struct Embedding : public PatternBase {
Embedding(PDPattern* pattern, const std::string& name_scope)
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......@@ -24,61 +25,27 @@ namespace ir {
using string::PrettyLogDetail;
void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
std::vector<std::string> act_types = {"relu",
"mish",
"swish",
"sqrt",
"hard_swish",
"sigmoid",
"abs",
"gelu",
"relu6",
"clip",
"tanh",
"hard_sigmoid",
"leaky_relu"};
auto act_types = paddle::platform::GetSupportedActivations();
std::vector<std::string> conv_types = {"conv2d"};
for (const auto& conv_type : conv_types)
for (auto& act_type : act_types) {
std::unordered_map<std::string, std::string> attrs_map;
if (act_type == "swish")
attrs_map.emplace("beta", "fuse_alpha");
else if (act_type == "relu6")
attrs_map.emplace("threshold", "fuse_alpha");
else if (act_type == "hard_sigmoid") {
attrs_map.emplace("slope", "fuse_alpha");
attrs_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") {
attrs_map.emplace("min", "fuse_alpha");
attrs_map.emplace("max", "fuse_beta");
} else {
attrs_map.emplace("alpha", "fuse_alpha");
attrs_map.emplace("beta", "fuse_beta");
}
FuseConvAct(graph, conv_type, act_type, attrs_map);
FuseConvAct(graph, conv_type, act_type);
}
}
void ConvActivationMkldnnFusePass::FuseConvAct(
Graph* graph,
void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
const std::string& conv_type,
std::string& act_type,
const std::unordered_map<std::string, std::string>& attrs_map) const {
std::string& act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
->NewNode("conv_activation_mkldnn_fuse/conv_input")
->AsInput()
->assert_is_op_input(conv_type, "Input");
patterns::ConvActivation conv_act_pattern(gpd.mutable_pattern(),
patterns::OperatorActivation conv_act_pattern(gpd.mutable_pattern(),
"conv_activation_mkldnn_fuse");
conv_act_pattern(conv_input, conv_type, act_type);
conv_act_pattern(conv_type, act_type);
int found_conv_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......@@ -90,16 +57,16 @@ void ConvActivationMkldnnFusePass::FuseConvAct(
return;
}
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv, preceding_op, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, preceding_op_out, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation, activation, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
OpDesc* conv_op = conv->Op();
OpDesc* act_op = activation->Op();
for (const auto& attrs : attrs_map) {
auto attr_map = paddle::platform::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
......
......@@ -31,11 +31,9 @@ class ConvActivationMkldnnFusePass : public FusePassBase {
protected:
void ApplyImpl(Graph *graph) const override;
void FuseConvAct(
Graph *graph,
void FuseConvAct(Graph *graph,
const std::string &conv_type,
std::string &act_type,
const std::unordered_map<std::string, std::string> &attrs_map) const;
std::string &act_type) const;
};
} // namespace ir
......
......@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......@@ -26,71 +27,40 @@ namespace ir {
using string::PrettyLogDetail;
void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {"relu",
"tanh",
"leaky_relu",
"swish",
"hard_swish",
"sqrt",
"abs",
"clip",
"gelu",
"relu6",
"sigmoid"};
auto act_types = paddle::platform::GetSupportedActivations();
std::vector<std::string> elt_types = {
"elementwise_add", "elementwise_sub", "elementwise_mul"};
for (const auto &elt_type : elt_types)
for (const auto &act_type : act_types) {
std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish")
attr_map.emplace("beta", "activation_alpha");
else if (act_type == "relu6")
attr_map.emplace("threshold", "activation_alpha");
else if (act_type == "clip") {
attr_map.emplace("min", "activation_alpha");
attr_map.emplace("max", "activation_beta");
} else {
attr_map.emplace("alpha", "activation_alpha");
attr_map.emplace("beta", "activation_beta");
}
FuseElementwiseAct(graph, elt_type, act_type, attr_map);
FuseElementwiseAct(graph, elt_type, act_type);
}
}
void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
Graph *graph,
const std::string &elt_type,
const std::string &act_type,
const std::unordered_map<std::string, std::string> &attr_map) const {
const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
auto *elementwise_input = gpd.mutable_pattern()
->NewNode(elt_type + "_act/elementwise_input")
->AsInput()
->assert_is_op_input(elt_type, "X");
patterns::ElementwiseActivation elementwise_act_pattern(gpd.mutable_pattern(),
patterns::OperatorActivation elementwise_act_pattern(gpd.mutable_pattern(),
elt_type + "_act");
elementwise_act_pattern(elementwise_input, elt_type, act_type);
elementwise_act_pattern(elt_type, act_type);
int found_elementwise_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "Fuse " << elt_type << " with activation op.";
// Elementwise output
GET_IR_NODE_FROM_SUBGRAPH(
elementwise_out, elementwise_out, elementwise_act_pattern);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, elementwise_act_pattern);
// ops
elementwise, preceding_op, elementwise_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise, elementwise, elementwise_act_pattern);
elementwise_out, preceding_op_out, elementwise_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation, activation, elementwise_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, elementwise_act_pattern);
auto *elementwise_op = elementwise->Op();
......@@ -106,6 +76,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
}
auto *activation_op = activation->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
elementwise_op->SetAttr(attr.second,
......@@ -115,9 +86,9 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
BOOST_GET_CONST(bool, activation_op->GetAttr("approximate")))
elementwise_op->SetAttr("activation_type", std::string("gelu_tanh"));
elementwise_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
else
elementwise_op->SetAttr("activation_type", act_type);
elementwise_op->SetAttr("fuse_activation", act_type);
elementwise_op->SetOutput("Out", {activation_out->Name()});
......@@ -146,14 +117,16 @@ REGISTER_PASS_CAPABILITY(elt_act_mkldnn_fuse_pass)
.LE("elementwise_add", 1)
.LE("elementwise_sub", 1)
.LE("elementwise_mul", 1)
.LE("relu", 0)
.LE("tanh", 0)
.LE("leaky_relu", 1)
.LE("swish", 0)
.LE("hard_swish", 0)
.LE("sqrt", 0)
.LE("abs", 0)
.EQ("abs", 0)
.LE("clip", 1)
.LE("gelu", 0)
.LE("relu6", 0)
.LE("sigmoid", 0));
.EQ("gelu", 0)
.EQ("hard_sigmoid", 0)
.LE("hard_swish", 0)
.LE("leaky_relu", 1)
.LE("mish", 1)
.EQ("relu", 0)
.EQ("relu6", 0)
.EQ("sigmoid", 0)
.EQ("sqrt", 0)
.EQ("swish", 0)
.EQ("tanh", 0));
......@@ -34,11 +34,9 @@ class ElementwiseActivationOneDNNPass : public FusePassBase {
protected:
void ApplyImpl(Graph *graph) const override;
void FuseElementwiseAct(
Graph *graph,
void FuseElementwiseAct(Graph *graph,
const std::string &elt_types,
const std::string &act_types,
const std::unordered_map<std::string, std::string> &attr_map) const;
const std::string &act_types) const;
};
} // namespace ir
......
......@@ -39,20 +39,17 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
FusePassBase::Init("fc_act", graph);
GraphPatternDetector gpd;
patterns::FCActOneDNN fc_act_pattern(gpd.mutable_pattern(), "fc_act");
fc_act_pattern(act_type);
patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
fc_act_pattern("fc", act_type);
int found_fc_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "Fuse fc with activation op.";
// FC output
GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_act_pattern);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fc_act_pattern);
// ops
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc, preceding_op, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, preceding_op_out, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, activation, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, activation_out, fc_act_pattern);
auto *fc_op = fc->Op();
auto *act_op = act->Op();
......
......@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......@@ -26,59 +27,34 @@ namespace ir {
using string::PrettyLogDetail;
void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {"relu",
"tanh",
"leaky_relu",
"swish",
"hardswish",
"sqrt",
"abs",
"clip",
"gelu",
"relu6",
"sigmoid"};
auto act_types = paddle::platform::GetSupportedActivations();
for (const auto &act_type : act_types) {
std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish")
attr_map.emplace("beta", "fuse_activation_alpha");
else if (act_type == "relu6")
attr_map.emplace("threshold", "fuse_activation_alpha");
else if (act_type == "clip") {
attr_map.emplace("min", "fuse_activation_alpha");
attr_map.emplace("max", "fuse_activation_beta");
} else {
attr_map.emplace("alpha", "fuse_activation_alpha");
attr_map.emplace("beta", "fuse_activation_beta");
}
FuseSoftplusActivation(graph, act_type, attr_map);
FuseSoftplusActivation(graph, act_type);
}
}
void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
Graph *graph,
const std::string &fuse_activation_type,
const std::unordered_map<std::string, std::string> &attr_map) const {
Graph *graph, const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("softplus_activation", graph);
GraphPatternDetector gpd;
patterns::SoftplusActivation softplus_activation_pattern(
patterns::OperatorActivation softplus_activation_pattern(
gpd.mutable_pattern(), "softplus_activation");
softplus_activation_pattern(fuse_activation_type);
softplus_activation_pattern("softplus", act_type);
int found_softplus_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "Fuse softplus with activation op.";
GET_IR_NODE_FROM_SUBGRAPH(
softplus_out, softplus_out, softplus_activation_pattern);
softplus_out, preceding_op_out, softplus_activation_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, softplus_activation_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softplus, softplus, softplus_activation_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softplus, preceding_op, softplus_activation_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
activation, activation, softplus_activation_pattern);
......@@ -94,18 +70,18 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
}
auto *activation_op = activation->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first));
}
}
if (fuse_activation_type == "gelu" &&
activation_op->HasAttr("approximate") &&
if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
BOOST_GET_CONST(bool, activation_op->GetAttr("approximate")))
softplus_op->SetAttr("fuse_activation_type", std::string("gelu_tanh"));
softplus_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
else
softplus_op->SetAttr("fuse_activation_type", fuse_activation_type);
softplus_op->SetAttr("fuse_activation", act_type);
softplus_op->SetAttr("use_mkldnn", true);
......@@ -121,7 +97,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
if (!Has("disable_logs") || !Get<bool>("disable_logs"))
PrettyLogDetail("--- fused %d softplus with %s activation",
found_softplus_activation_count,
fuse_activation_type);
act_type);
}
} // namespace ir
} // namespace framework
......@@ -133,13 +109,16 @@ REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("softplus", 1)
.EQ("relu", 0)
.EQ("tanh", 0)
.LE("leaky_relu", 1)
.EQ("swish", 0)
.EQ("hard_swish", 0)
.EQ("sqrt", 0)
.EQ("abs", 0)
.LE("relu6", 1)
.LE("clip", 1)
.EQ("gelu", 0));
.EQ("gelu", 0)
.EQ("hard_sigmoid", 0)
.LE("hard_swish", 0)
.LE("leaky_relu", 1)
.LE("mish", 1)
.EQ("relu", 0)
.EQ("relu6", 0)
.EQ("sigmoid", 0)
.EQ("sqrt", 0)
.EQ("swish", 0)
.EQ("tanh", 0));
......@@ -34,10 +34,8 @@ class SoftplusActivationOneDNNPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph *graph) const override;
void FuseSoftplusActivation(
ir::Graph *graph,
const std::string &fuse_activation_type,
const std::unordered_map<std::string, std::string> &attr_map) const;
void FuseSoftplusActivation(ir::Graph *graph,
const std::string &act_type) const;
};
} // namespace ir
......
......@@ -50,22 +50,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
private:
dnnl::post_ops get_post_ops(const framework::ExecutionContext& ctx) const {
dnnl::post_ops post_operations;
if (ctx.HasAttr("activation_type")) {
const float scale = ctx.HasAttr("activation_scale")
? ctx.Attr<float>("activation_scale")
: 1.0f;
const float alpha = ctx.HasAttr("activation_alpha")
? ctx.Attr<float>("activation_alpha")
: 0.0f;
const float beta = ctx.HasAttr("activation_beta")
? ctx.Attr<float>("activation_beta")
: 0.0f;
const auto activation_algorithm = platform::AcquireActivationAlgorithm(
ctx.Attr<std::string>("activation_type"));
post_operations.append_eltwise(scale, activation_algorithm, alpha, beta);
}
platform::AppendActivation(ctx, post_operations);
return post_operations;
}
......
......@@ -553,10 +553,6 @@ class ConvMKLDNNHandlerT
dnnl::primitive_attr conv_attr;
dnnl::post_ops post_operations;
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
float sum_scale = 1.0f;
......@@ -587,19 +583,7 @@ class ConvMKLDNNHandlerT
post_operations.append_sum(sum_scale);
}
if (fuse_activation == "hard_sigmoid") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
fuse_alpha,
fuse_beta);
post_operations.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
} else if (fuse_activation != "") {
const auto activation_algorithm =
platform::AcquireActivationAlgorithm(fuse_activation);
post_operations.append_eltwise(
activation_scale, activation_algorithm, fuse_alpha, fuse_beta);
}
platform::AppendActivation(ctx, post_operations, activation_scale);
conv_attr.set_post_ops(post_operations);
return conv_attr;
......
......@@ -46,7 +46,7 @@ class SoftplusMKLDNNHandler
1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f);
}
AppendFusedActivationIfExists(ctx, &post_ops);
platform::AppendActivation(ctx, post_ops);
dnnl::primitive_attr attrs;
attrs.set_post_ops(post_ops);
......@@ -62,42 +62,8 @@ class SoftplusMKLDNNHandler
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->src1_desc(), platform::to_void_cast<float>(beta));
}
private:
void AppendFusedActivationIfExists(const framework::ExecutionContext& ctx,
dnnl::post_ops* post_ops) {
const auto& fused_activation_type =
algo_map.find(ctx.Attr<std::string>("fuse_activation_type"));
if (fused_activation_type != algo_map.end()) {
auto scale_out =
ctx.Attr<float>("fuse_activation_scale"); // for future int8 support
post_ops->append_eltwise(scale_out,
fused_activation_type->second,
ctx.Attr<float>("fuse_activation_alpha"),
ctx.Attr<float>("fuse_activation_beta"));
}
}
static const std::unordered_map<std::string, dnnl::algorithm> algo_map;
};
template <typename T>
const std::unordered_map<std::string, dnnl::algorithm>
SoftplusMKLDNNHandler<T>::algo_map = {
{"relu", dnnl::algorithm::eltwise_relu},
{"tanh", dnnl::algorithm::eltwise_tanh},
{"leaky_relu", dnnl::algorithm::eltwise_relu},
{"swish", dnnl::algorithm::eltwise_swish},
{"hardswish", dnnl::algorithm::eltwise_hardswish},
{"sqrt", dnnl::algorithm::eltwise_sqrt},
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
{"sigmoid", dnnl::algorithm::eltwise_logistic}};
template <typename T>
void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) {
const auto& dev_ctx =
......
......@@ -1013,9 +1013,30 @@ class ActivationMKLDNNHandler
}
};
static const dnnl::algorithm AcquireActivationAlgorithm(
std::string activation_name) {
std::unordered_map<std::string, dnnl::algorithm> activation_map = {
static void AppendActivation(const framework::ExecutionContext& ctx,
dnnl::post_ops& post_ops,
float activation_scale = 1.0f) {
const auto invalid_attribute =
ctx.HasAttr("fuse_activation")
? ctx.Attr<std::string>("fuse_activation").empty()
: true;
if (invalid_attribute) return;
const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
const auto fuse_alpha =
ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
const auto fuse_beta =
ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;
if (fuse_activation == "hard_sigmoid") {
post_ops.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
fuse_alpha,
fuse_beta);
post_ops.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
} else {
const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
......@@ -1031,14 +1052,54 @@ static const dnnl::algorithm AcquireActivationAlgorithm(
{"swish", dnnl::algorithm::eltwise_swish},
{"tanh", dnnl::algorithm::eltwise_tanh}};
const auto& activation_type = activation_map.find(activation_name);
const auto& activation_type = activation_map.find(fuse_activation);
PADDLE_ENFORCE_NE(activation_type,
PADDLE_ENFORCE_NE(
activation_type,
activation_map.end(),
platform::errors::InvalidArgument(
"Activation '%s' not found in oneDNN algorithms mapper",
activation_name));
return activation_type->second;
fuse_activation));
post_ops.append_eltwise(
activation_scale, activation_type->second, fuse_alpha, fuse_beta);
}
}
static std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) {
std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish")
attr_map.emplace("beta", "fuse_alpha");
else if (act_type == "relu6")
attr_map.emplace("threshold", "fuse_alpha");
else if (act_type == "hard_sigmoid") {
attr_map.emplace("slope", "fuse_alpha");
attr_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") {
attr_map.emplace("min", "fuse_alpha");
attr_map.emplace("max", "fuse_beta");
} else {
attr_map.emplace("alpha", "fuse_alpha");
attr_map.emplace("beta", "fuse_beta");
}
return attr_map;
}
static std::vector<std::string> GetSupportedActivations() {
return std::vector<std::string>{"abs",
"clip",
"gelu",
"hard_sigmoid",
"hard_swish",
"leaky_relu",
"mish",
"relu",
"relu6",
"sigmoid",
"sqrt",
"swish",
"tanh"};
}
class ReorderMKLDNNHandler {
......
......@@ -23,8 +23,8 @@ from paddle.fluid.core import PassVersionChecker
class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
fuse_activation_alpha = None
fuse_activation_beta = None
fuse_alpha = None
fuse_beta = None
pass_name = 'softplus_activation_mkldnn_fuse_pass'
def setUp(self):
......@@ -34,13 +34,13 @@ class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
shape=[-1, 3, 100, 100],
dtype="float32")
softplus_out = fluid.layers.softplus(data)
if self.fuse_activation_beta is not None:
activation_out = self.fuse_activation(
softplus_out, self.fuse_activation_alpha,
self.fuse_activation_beta)
elif self.fuse_activation_alpha is not None:
activation_out = self.fuse_activation(
softplus_out, self.fuse_activation_alpha)
if self.fuse_beta is not None:
activation_out = self.fuse_activation(softplus_out,
self.fuse_alpha,
self.fuse_beta)
elif self.fuse_alpha is not None:
activation_out = self.fuse_activation(softplus_out,
self.fuse_alpha)
else:
activation_out = self.fuse_activation(softplus_out)
......@@ -73,7 +73,7 @@ class SoftplusActivationLeakyReluOneDNNFusePassTest(
def set_params(self):
self.fuse_activation = fluid.layers.leaky_relu
self.fuse_activation_alpha = 0.3
self.fuse_alpha = 0.3
class SoftplusActivationSwishOneDNNFusePassTest(
......@@ -81,7 +81,7 @@ class SoftplusActivationSwishOneDNNFusePassTest(
def set_params(self):
self.fuse_activation = fluid.layers.swish
self.fuse_activation_alpha = 3
self.fuse_alpha = 3
class SoftplusActivationHardSwishOneDNNFusePassTest(
......@@ -110,8 +110,8 @@ class SoftplusActivationClipOneDNNFusePassTest(
def set_params(self):
self.fuse_activation = fluid.layers.clip
self.fuse_activation_alpha = 1.1
self.fuse_activation_beta = 5.2
self.fuse_alpha = 1.1
self.fuse_beta = 5.2
class SoftplusActivationGeluErfOneDNNFusePassTest(
......@@ -126,7 +126,7 @@ class SoftplusActivationGeluTanhOneDNNFusePassTest(
def set_params(self):
self.fuse_activation = fluid.layers.gelu
self.fuse_activation_alpha = True # simulated "Approximate" attr
self.fuse_alpha = True # simulated "Approximate" attr
class SoftplusActivationRelu6OneDNNFusePassTest(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册