Unverified commit a1b2e1e2 authored by Sławomir Siwek, committed by GitHub

Handle repetitive code in oneDNN activation fuse passes (#49824)

* extract fuse pass logic to header file

* adjust namespaces

* Update paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h

update date
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>

* add inline, remove static
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>
Parent 24379442
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
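// Activation op types that the oneDNN fuse passes can fold into the
// preceding operator as a post-op.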
inline std::vector<std::string> GetSupportedActivations() {
  return std::vector<std::string>{"abs",
                                  "clip",
                                  "gelu",
                                  "hard_sigmoid",
                                  "hard_swish",
                                  "leaky_relu",
                                  "mish",
                                  "relu",
                                  "relu6",
                                  "sigmoid",
                                  "sqrt",
                                  "swish",
                                  "tanh"};
}
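// Maps an activation op's native attribute names to the generic
// fuse_alpha / fuse_beta attributes carried by fused oneDNN operators.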
inline std::unordered_map<std::string, std::string> GetAttributeMap(
    std::string act_type) {
  std::unordered_map<std::string, std::string> attr_map;
  if (act_type == "swish") {
    attr_map.emplace("beta", "fuse_alpha");
  } else if (act_type == "relu6") {
    attr_map.emplace("threshold", "fuse_alpha");
  } else if (act_type == "hard_sigmoid") {
    attr_map.emplace("slope", "fuse_alpha");
    attr_map.emplace("offset", "fuse_beta");
  } else if (act_type == "clip") {
    attr_map.emplace("min", "fuse_alpha");
    attr_map.emplace("max", "fuse_beta");
  } else {
    attr_map.emplace("alpha", "fuse_alpha");
    attr_map.emplace("beta", "fuse_beta");
  }
  return attr_map;
}
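// Copies the matched activation's attributes onto the fused op, forces
// use_mkldnn, and records the activation type in fuse_activation
// (gelu is resolved to gelu_tanh or gelu_erf based on "approximate").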
inline void SetActivationAttrs(paddle::framework::OpDesc* fused_op,
                               paddle::framework::OpDesc* act_op,
                               const std::string& act_type) {
  if (fused_op->HasAttr("use_mkldnn")) {
    PADDLE_ENFORCE(PADDLE_GET_CONST(bool, fused_op->GetAttr("use_mkldnn")),
                   phi::errors::PreconditionNotMet(
                       "oneDNN activation fuses require use_mkldnn=True"));
  }
  fused_op->SetAttr("use_mkldnn", true);
  auto attr_map = GetAttributeMap(act_type);
  for (const auto& attr : attr_map) {
    if (act_op->HasAttr(attr.first)) {
      fused_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
    }
  }
  if (act_type == "gelu" && act_op->HasAttr("approximate")) {
    std::string gelu_act_type =
        PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
                                                               : "gelu_erf";
    fused_op->SetAttr("fuse_activation", gelu_act_type);
  } else {
    fused_op->SetAttr("fuse_activation", act_type);
  }
}
} // namespace ir
} // namespace framework
} // namespace paddle
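For orientation, a minimal usage sketch (illustrative only, not part of this commit) of how a fuse pass consumes the helpers above; the namespace and function name below are placeholders:

// Hypothetical caller, shown for illustration; only SetActivationAttrs and
// GetSupportedActivations come from the header added by this commit.
#include <string>
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"

namespace example {
void FuseOpWithActivation(paddle::framework::OpDesc* fused_op,
                          paddle::framework::OpDesc* act_op,
                          const std::string& act_type) {
  // One call replaces the per-pass attribute mapping, use_mkldnn check,
  // and gelu special-casing that the passes below used to duplicate.
  paddle::framework::ir::SetActivationAttrs(fused_op, act_op, act_type);
  // The pass remains responsible for rewiring the fused op's output and
  // removing the detached activation node from the graph.
}
}  // namespace example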
@@ -14,8 +14,8 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h"
namespace paddle {
@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;
void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
auto act_types = GetSupportedActivations();
std::vector<std::string> conv_types = {"fused_conv2d", "conv2d"};
for (auto& act_type : act_types) {
@@ -40,7 +40,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
const std::string& conv_type,
std::string& act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
@@ -62,28 +62,13 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
OpDesc* conv_op = conv->Op();
OpDesc* act_op = activation->Op();
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
}
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
}
SetActivationAttrs(conv_op, activation->Op(), act_type);
if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
act_type =
PADDLE_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
? "gelu_tanh"
: "gelu_erf";
conv_op->SetAttr("fuse_alpha", 0.0f);
conv_op->SetAttr("fuse_beta", 0.0f);
}
conv_op->SetAttr("fuse_activation", act_type);
conv_op->SetOutput("Output", {activation_out->Name()});
IR_NODE_LINK_TO(conv, activation_out);
@@ -105,7 +90,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
void ConvActivationMkldnnFusePass::FuseConvConcatAct(
Graph* graph, std::string& act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("conv2d_concat_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
@@ -137,13 +122,13 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
return;
}
bool is_not_conv_mkldnn =
bool is_not_conv_onednn =
!(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn"));
if ((prev_op_nodes[0]->Op()->Type() != "conv2d" &&
prev_op_nodes[0]->Op()->Type() != "fused_conv2d") ||
is_not_conv_mkldnn) {
LOG(WARNING) << "This fuse pass supports only conv2d(mkldnn) | "
"fused_conv2d(mkldnn) + activation.";
is_not_conv_onednn) {
LOG(WARNING) << "This fuse pass supports only conv2d(oneDNN) | "
"fused_conv2d(oneDNN) + activation.";
return;
}
}
@@ -153,23 +138,8 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
}
OpDesc* act_op = activation_op->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
act_type = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"))
? "gelu_tanh"
: "gelu_erf";
conv_op->SetAttr("fuse_alpha", 0.0f);
conv_op->SetAttr("fuse_beta", 0.0f);
}
conv_op->SetAttr("fuse_activation", act_type);
SetActivationAttrs(conv_op, activation_op->Op(), act_type);
}
concat_op->Op()->SetOutput("Out", {activation_out->Name()});
......
@@ -15,8 +15,8 @@
#include "paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/string/pretty_log.h"
@@ -27,7 +27,7 @@ namespace ir {
using string::PrettyLogDetail;
void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
auto act_types = GetSupportedActivations();
std::vector<std::string> elt_types = {
"elementwise_add", "elementwise_sub", "elementwise_mul"};
@@ -42,7 +42,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
const std::string &elt_type,
const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
@@ -62,35 +62,8 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, elementwise_act_pattern);
auto *elementwise_op = elementwise->Op();
if (elementwise_op->HasAttr("use_mkldnn")) {
const std::string wo_elt_type =
"The " + elt_type; // Workaround for PP error message checking.
PADDLE_ENFORCE_EQ(
PADDLE_GET_CONST(bool, elementwise_op->GetAttr("use_mkldnn")),
true,
platform::errors::PreconditionNotMet(
wo_elt_type + "+Act fusion may happen only when oneDNN library "
"is used."));
}
auto *activation_op = activation->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
elementwise_op->SetAttr(attr.second,
activation_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
PADDLE_GET_CONST(bool, activation_op->GetAttr("approximate")))
elementwise_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
else
elementwise_op->SetAttr("fuse_activation", act_type);
elementwise_op->SetOutput("Out", {activation_out->Name()});
SetActivationAttrs(elementwise->Op(), activation->Op(), act_type);
elementwise->Op()->SetOutput("Out", {activation_out->Name()});
IR_OP_VAR_LINK(elementwise, activation_out);
GraphSafeRemoveNodes(g, {activation, elementwise_out});
......
@@ -14,8 +14,8 @@
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h"
namespace paddle {
@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
auto act_types = GetSupportedActivations();
for (auto act_type : act_types) FuseFCAct(graph, act_type);
}
@@ -33,7 +33,7 @@ void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
@@ -50,35 +50,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
GET_IR_NODE_FROM_SUBGRAPH(act, activation, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, activation_out, fc_act_pattern);
auto *fc_op = fc->Op();
auto *act_op = act->Op();
if (fc_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE(
PADDLE_GET_CONST(bool, fc_op->GetAttr("use_mkldnn")),
platform::errors::PreconditionNotMet(
"The FC+Act fusion may happen only when oneDNN library "
"is used."));
}
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (act_op->HasAttr(attr.first)) {
fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
std::string gelu_act_type =
PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
: "gelu_erf";
fc_op->SetAttr("fuse_activation", gelu_act_type);
} else {
fc_op->SetAttr("fuse_activation", act_type);
}
fc_op->SetAttr("use_mkldnn", true);
fc_op->SetOutput("Out", {act_out->Name()});
SetActivationAttrs(fc->Op(), act->Op(), act_type);
fc->Op()->SetOutput("Out", {act_out->Name()});
IR_OP_VAR_LINK(fc, act_out);
GraphSafeRemoveNodes(g, {act, fc_out});
......
@@ -14,8 +14,8 @@
#include "paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h"
namespace paddle {
@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;
void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
auto act_types = GetSupportedActivations();
auto matmul_types = {"matmul", "matmul_v2"};
for (const auto& matmul_type : matmul_types)
@@ -37,7 +37,7 @@ void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
void MatmulActivationMkldnnFusePass::FuseMatmulAct(
Graph* graph, const std::string& matmul_type, std::string& act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(matmul_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
@@ -61,24 +61,8 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, matmul_act_pattern);
OpDesc* matmul_op = matmul->Op();
OpDesc* act_op = activation->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
matmul_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
}
if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
act_type =
PADDLE_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
? "gelu_tanh"
: "gelu_erf";
}
matmul_op->SetAttr("fuse_activation", act_type);
matmul_op->SetOutput("Out", {activation_out->Name()});
SetActivationAttrs(matmul->Op(), activation->Op(), act_type);
matmul->Op()->SetOutput("Out", {activation_out->Name()});
IR_NODE_LINK_TO(matmul, activation_out);
GraphSafeRemoveNodes(graph, {activation, matmul_out});
......
@@ -15,8 +15,8 @@
#include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/string/pretty_log.h"
@@ -27,7 +27,7 @@ namespace ir {
using string::PrettyLogDetail;
void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
auto act_types = GetSupportedActivations();
// Currently softplus can't be fused with hard_sigmoid
act_types.erase(
@@ -42,7 +42,7 @@ void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
Graph *graph, const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("softplus_activation", graph);
GraphPatternDetector gpd;
@@ -63,34 +63,8 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
GET_IR_NODE_FROM_SUBGRAPH(
activation, activation, softplus_activation_pattern);
auto *softplus_op = softplus->Op();
if (softplus_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE_EQ(
PADDLE_GET_CONST(bool, softplus_op->GetAttr("use_mkldnn")),
true,
platform::errors::PreconditionNotMet("The softplus + activation "
"fusion may happen only when "
"oneDNN library is used."));
}
auto *activation_op = activation->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
PADDLE_GET_CONST(bool, activation_op->GetAttr("approximate")))
softplus_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
else
softplus_op->SetAttr("fuse_activation", act_type);
softplus_op->SetAttr("use_mkldnn", true);
softplus_op->SetOutput("Out", {activation_out->Name()});
SetActivationAttrs(softplus->Op(), activation->Op(), act_type);
softplus->Op()->SetOutput("Out", {activation_out->Name()});
IR_OP_VAR_LINK(softplus, activation_out);
GraphSafeRemoveNodes(g, {activation, softplus_out});
......
@@ -112,42 +112,6 @@ static void AppendActivation(const OneDNNContext& dev_ctx,
}
}
static std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) {
std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish") {
attr_map.emplace("beta", "fuse_alpha");
} else if (act_type == "relu6") {
attr_map.emplace("threshold", "fuse_alpha");
} else if (act_type == "hard_sigmoid") {
attr_map.emplace("slope", "fuse_alpha");
attr_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") {
attr_map.emplace("min", "fuse_alpha");
attr_map.emplace("max", "fuse_beta");
} else {
attr_map.emplace("alpha", "fuse_alpha");
attr_map.emplace("beta", "fuse_beta");
}
return attr_map;
}
static std::vector<std::string> GetSupportedActivations() {
return std::vector<std::string>{"abs",
"clip",
"gelu",
"hard_sigmoid",
"hard_swish",
"leaky_relu",
"mish",
"relu",
"relu6",
"sigmoid",
"sqrt",
"swish",
"tanh"};
}
template <typename T,
typename TForward,
typename TBackward = onednn_dummy_primitive,
@@ -1756,13 +1720,13 @@ static std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x,
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
paddle::platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank,
PADDLE_ENFORCE_EQ(
in_rank,
axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
phi::errors::InvalidArgument("The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
@@ -1771,7 +1735,7 @@ static std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x,
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
paddle::platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
......