From a1b2e1e2b73d5ce3ec38315a329385f60aa92b68 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C5=82awomir=20Siwek?=
Date: Wed, 18 Jan 2023 15:37:27 +0100
Subject: [PATCH] Handle repetitive code in oneDNN activation fuse passes
 (#49824)

* extract fuse pass logic to header file

* adjust namespaces

* Update paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h

update date

Co-authored-by: Tomasz Socha

* add inline remove static

Co-authored-by: Tomasz Socha
---
 .../ir/mkldnn/activation_onednn_fuse_pass.h  | 87 +++++++++++++++++++
 .../conv_activation_mkldnn_fuse_pass.cc      | 50 +++--------
 .../ir/mkldnn/elt_act_mkldnn_fuse_pass.cc    | 37 ++------
 .../ir/mkldnn/fc_act_mkldnn_fuse_pass.cc     | 37 ++------
 .../matmul_activation_mkldnn_fuse_pass.cc    | 26 ++----
 .../softplus_activation_mkldnn_fuse_pass.cc  | 36 ++------
 paddle/phi/backends/onednn/onednn_reuse.h    | 58 +++----------
 7 files changed, 128 insertions(+), 203 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h

diff --git a/paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h
new file mode 100644
index 00000000000..64232a2b725
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h
@@ -0,0 +1,87 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/framework/op_desc.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+inline std::vector<std::string> GetSupportedActivations() {
+  return std::vector<std::string>{"abs",
+                                  "clip",
+                                  "gelu",
+                                  "hard_sigmoid",
+                                  "hard_swish",
+                                  "leaky_relu",
+                                  "mish",
+                                  "relu",
+                                  "relu6",
+                                  "sigmoid",
+                                  "sqrt",
+                                  "swish",
+                                  "tanh"};
+}
+
+inline std::unordered_map<std::string, std::string> GetAttributeMap(
+    std::string act_type) {
+  std::unordered_map<std::string, std::string> attr_map;
+  if (act_type == "swish") {
+    attr_map.emplace("beta", "fuse_alpha");
+  } else if (act_type == "relu6") {
+    attr_map.emplace("threshold", "fuse_alpha");
+  } else if (act_type == "hard_sigmoid") {
+    attr_map.emplace("slope", "fuse_alpha");
+    attr_map.emplace("offset", "fuse_beta");
+  } else if (act_type == "clip") {
+    attr_map.emplace("min", "fuse_alpha");
+    attr_map.emplace("max", "fuse_beta");
+  } else {
+    attr_map.emplace("alpha", "fuse_alpha");
+    attr_map.emplace("beta", "fuse_beta");
+  }
+  return attr_map;
+}
+
+inline void SetActivationAttrs(paddle::framework::OpDesc* fused_op,
+                               paddle::framework::OpDesc* act_op,
+                               const std::string& act_type) {
+  if (fused_op->HasAttr("use_mkldnn")) {
+    PADDLE_ENFORCE(PADDLE_GET_CONST(bool, fused_op->GetAttr("use_mkldnn")),
+                   phi::errors::PreconditionNotMet(
+                       "oneDNN activation fuses require use_mkldnn=True"));
+  }
+  fused_op->SetAttr("use_mkldnn", true);
+
+  auto attr_map = GetAttributeMap(act_type);
+  for (const auto& attr : attr_map) {
+    if (act_op->HasAttr(attr.first)) {
+      fused_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
+    }
+  }
+
+  if (act_type == "gelu" && act_op->HasAttr("approximate")) {
+    std::string gelu_act_type =
+        PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
+                                                               : "gelu_erf";
+    fused_op->SetAttr("fuse_activation", gelu_act_type);
+  } else {
+    fused_op->SetAttr("fuse_activation", act_type);
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 2db957d84da..f905df3e53c 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -14,8 +14,8 @@

 #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"

+#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/utils/string/pretty_log.h"

 namespace paddle {
@@ -25,7 +25,7 @@ namespace ir {
 using string::PrettyLogDetail;

 void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
-  auto act_types = phi::funcs::GetSupportedActivations();
+  auto act_types = GetSupportedActivations();
   std::vector<std::string> conv_types = {"fused_conv2d", "conv2d"};

   for (auto& act_type : act_types) {
@@ -40,7 +40,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
                                                const std::string& conv_type,
                                                std::string& act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
-      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+      graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
@@ -62,28 +62,13 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);

     OpDesc* conv_op = conv->Op();
-    OpDesc* act_op = activation->Op();

     if (conv_op->Type() == "conv2d") {
       conv_op->SetType("fused_conv2d");
     }

-    auto attr_map = phi::funcs::GetAttributeMap(act_type);
-    for (const auto& attrs : attr_map) {
-      if (act_op->HasAttr(attrs.first)) {
-        conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
-      }
-    }
+    SetActivationAttrs(conv_op, activation->Op(), act_type);

-    if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
-      act_type =
-          PADDLE_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
-              ? "gelu_tanh"
-              : "gelu_erf";
-      conv_op->SetAttr("fuse_alpha", 0.0f);
-      conv_op->SetAttr("fuse_beta", 0.0f);
-    }
-    conv_op->SetAttr("fuse_activation", act_type);
     conv_op->SetOutput("Output", {activation_out->Name()});

     IR_NODE_LINK_TO(conv, activation_out);
@@ -105,7 +90,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
 void ConvActivationMkldnnFusePass::FuseConvConcatAct(
     Graph* graph, std::string& act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
-      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+      graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init("conv2d_concat_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
@@ -137,13 +122,13 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
         return;
       }

-      bool is_not_conv_mkldnn =
+      bool is_not_conv_onednn =
           !(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn"));
       if ((prev_op_nodes[0]->Op()->Type() != "conv2d" &&
            prev_op_nodes[0]->Op()->Type() != "fused_conv2d") ||
-          is_not_conv_mkldnn) {
-        LOG(WARNING) << "This fuse pass supports only conv2d(mkldnn) | "
-                        "fused_conv2d(mkldnn) + activation.";
+          is_not_conv_onednn) {
+        LOG(WARNING) << "This fuse pass supports only conv2d(oneDNN) | "
+                        "fused_conv2d(oneDNN) + activation.";
         return;
       }
     }
@@ -153,23 +138,8 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
       if (conv_op->Type() == "conv2d") {
         conv_op->SetType("fused_conv2d");
       }

-      OpDesc* act_op = activation_op->Op();
-      auto attr_map = phi::funcs::GetAttributeMap(act_type);
-      for (const auto& attrs : attr_map) {
-        if (act_op->HasAttr(attrs.first)) {
-          conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
-        }
-      }
-
-      if (act_type == "gelu" && act_op->HasAttr("approximate")) {
-        act_type = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"))
-                       ? "gelu_tanh"
-                       : "gelu_erf";
-        conv_op->SetAttr("fuse_alpha", 0.0f);
-        conv_op->SetAttr("fuse_beta", 0.0f);
-      }
-      conv_op->SetAttr("fuse_activation", act_type);
+      SetActivationAttrs(conv_op, activation_op->Op(), act_type);
     }

     concat_op->Op()->SetOutput("Out", {activation_out->Name()});
diff --git a/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc
index 618b6993729..4b3f6a95d6d 100644
--- a/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc
@@ -15,8 +15,8 @@
 #include "paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h"

 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/utils/string/pretty_log.h"

@@ -27,7 +27,7 @@ namespace ir {
 using string::PrettyLogDetail;

 void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
-  auto act_types = phi::funcs::GetSupportedActivations();
+  auto act_types = GetSupportedActivations();
   std::vector<std::string> elt_types = {
       "elementwise_add", "elementwise_sub", "elementwise_mul"};

@@ -42,7 +42,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
     const std::string &elt_type,
     const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
-      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+      graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
@@ -62,35 +62,8 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
     GET_IR_NODE_FROM_SUBGRAPH(
         activation_out, activation_out, elementwise_act_pattern);

-    auto *elementwise_op = elementwise->Op();
-
-    if (elementwise_op->HasAttr("use_mkldnn")) {
-      const std::string wo_elt_type =
-          "The " + elt_type;  // Workaround for PP error message checking.
-      PADDLE_ENFORCE_EQ(
-          PADDLE_GET_CONST(bool, elementwise_op->GetAttr("use_mkldnn")),
-          true,
-          platform::errors::PreconditionNotMet(
-              wo_elt_type + "+Act fusion may happen only when oneDNN library "
-                            "is used."));
-    }
-
-    auto *activation_op = activation->Op();
-    auto attr_map = phi::funcs::GetAttributeMap(act_type);
-    for (const auto &attr : attr_map) {
-      if (activation_op->HasAttr(attr.first)) {
-        elementwise_op->SetAttr(attr.second,
-                                activation_op->GetAttr(attr.first));
-      }
-    }
-
-    if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
-        PADDLE_GET_CONST(bool, activation_op->GetAttr("approximate")))
-      elementwise_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
-    else
-      elementwise_op->SetAttr("fuse_activation", act_type);
-
-    elementwise_op->SetOutput("Out", {activation_out->Name()});
+    SetActivationAttrs(elementwise->Op(), activation->Op(), act_type);
+    elementwise->Op()->SetOutput("Out", {activation_out->Name()});

     IR_OP_VAR_LINK(elementwise, activation_out);
     GraphSafeRemoveNodes(g, {activation, elementwise_out});
diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
index 60ab407f00c..d007ef16d33 100644
--- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
@@ -14,8 +14,8 @@

 #include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"

+#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/utils/string/pretty_log.h"

 namespace paddle {
@@ -25,7 +25,7 @@ namespace ir {
 using string::PrettyLogDetail;

 void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
-  auto act_types = phi::funcs::GetSupportedActivations();
+  auto act_types = GetSupportedActivations();

   for (auto act_type : act_types) FuseFCAct(graph, act_type);
 }
@@ -33,7 +33,7 @@ void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
 void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                     const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
-      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+      graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
@@ -50,35 +50,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
     GET_IR_NODE_FROM_SUBGRAPH(act, activation, fc_act_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(act_out, activation_out, fc_act_pattern);

-    auto *fc_op = fc->Op();
-    auto *act_op = act->Op();
-
-    if (fc_op->HasAttr("use_mkldnn")) {
-      PADDLE_ENFORCE(
-          PADDLE_GET_CONST(bool, fc_op->GetAttr("use_mkldnn")),
-          platform::errors::PreconditionNotMet(
-              "The FC+Act fusion may happen only when oneDNN library "
-              "is used."));
-    }
-
-    auto attr_map = phi::funcs::GetAttributeMap(act_type);
-    for (const auto &attr : attr_map) {
-      if (act_op->HasAttr(attr.first)) {
-        fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
-      }
-    }
-
-    if (act_type == "gelu" && act_op->HasAttr("approximate")) {
-      std::string gelu_act_type =
-          PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
-                                                                 : "gelu_erf";
-      fc_op->SetAttr("fuse_activation", gelu_act_type);
-    } else {
-      fc_op->SetAttr("fuse_activation", act_type);
-    }
-
-    fc_op->SetAttr("use_mkldnn", true);
-    fc_op->SetOutput("Out", {act_out->Name()});
+    SetActivationAttrs(fc->Op(), act->Op(), act_type);
+    fc->Op()->SetOutput("Out", {act_out->Name()});

     IR_OP_VAR_LINK(fc, act_out);
     GraphSafeRemoveNodes(g, {act, fc_out});
diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
index 07a608c5a2b..50db74e46d1 100644
--- a/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
@@ -14,8 +14,8 @@

 #include "paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h"

+#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/utils/string/pretty_log.h"

 namespace paddle {
@@ -25,7 +25,7 @@ namespace ir {
 using string::PrettyLogDetail;

 void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
-  auto act_types = phi::funcs::GetSupportedActivations();
+  auto act_types = GetSupportedActivations();
   auto matmul_types = {"matmul", "matmul_v2"};

   for (const auto& matmul_type : matmul_types)
@@ -37,7 +37,7 @@ void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
 void MatmulActivationMkldnnFusePass::FuseMatmulAct(
     Graph* graph, const std::string& matmul_type, std::string& act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
-      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+      graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(matmul_type + "_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
@@ -61,24 +61,8 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
     GET_IR_NODE_FROM_SUBGRAPH(
         activation_out, activation_out, matmul_act_pattern);

-    OpDesc* matmul_op = matmul->Op();
-    OpDesc* act_op = activation->Op();
-
-    auto attr_map = phi::funcs::GetAttributeMap(act_type);
-    for (const auto& attrs : attr_map) {
-      if (act_op->HasAttr(attrs.first)) {
-        matmul_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
-      }
-    }
-
-    if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
-      act_type =
-          PADDLE_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
-              ? "gelu_tanh"
-              : "gelu_erf";
-    }
-    matmul_op->SetAttr("fuse_activation", act_type);
-    matmul_op->SetOutput("Out", {activation_out->Name()});
+    SetActivationAttrs(matmul->Op(), activation->Op(), act_type);
+    matmul->Op()->SetOutput("Out", {activation_out->Name()});

     IR_NODE_LINK_TO(matmul, activation_out);
     GraphSafeRemoveNodes(graph, {activation, matmul_out});
diff --git a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
index a4e74bb376d..0954414bee1 100644
--- a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
@@ -15,8 +15,8 @@
 #include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h"

 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/utils/string/pretty_log.h"

@@ -27,7 +27,7 @@ namespace ir {
 using string::PrettyLogDetail;

 void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
-  auto act_types = phi::funcs::GetSupportedActivations();
+  auto act_types = GetSupportedActivations();

   // Currently softplus can't be fused with hard_sigmoid
   act_types.erase(
@@ -42,7 +42,7 @@ void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
 void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
     Graph *graph, const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
-      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+      graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init("softplus_activation", graph);

   GraphPatternDetector gpd;
@@ -63,34 +63,8 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
     GET_IR_NODE_FROM_SUBGRAPH(
         activation, activation, softplus_activation_pattern);

-    auto *softplus_op = softplus->Op();
-
-    if (softplus_op->HasAttr("use_mkldnn")) {
-      PADDLE_ENFORCE_EQ(
-          PADDLE_GET_CONST(bool, softplus_op->GetAttr("use_mkldnn")),
-          true,
-          platform::errors::PreconditionNotMet("The softplus + activation "
-                                               "fusion may happen only when "
-                                               "oneDNN library is used."));
-    }
-
-    auto *activation_op = activation->Op();
-    auto attr_map = phi::funcs::GetAttributeMap(act_type);
-    for (const auto &attr : attr_map) {
-      if (activation_op->HasAttr(attr.first)) {
-        softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first));
-      }
-    }
-
-    if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
-        PADDLE_GET_CONST(bool, activation_op->GetAttr("approximate")))
-      softplus_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
-    else
-      softplus_op->SetAttr("fuse_activation", act_type);
-
-    softplus_op->SetAttr("use_mkldnn", true);
-
-    softplus_op->SetOutput("Out", {activation_out->Name()});
+    SetActivationAttrs(softplus->Op(), activation->Op(), act_type);
+    softplus->Op()->SetOutput("Out", {activation_out->Name()});

     IR_OP_VAR_LINK(softplus, activation_out);
     GraphSafeRemoveNodes(g, {activation, softplus_out});
diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h
index dbb70cb07aa..c398138e2d5 100644
--- a/paddle/phi/backends/onednn/onednn_reuse.h
+++ b/paddle/phi/backends/onednn/onednn_reuse.h
@@ -112,42 +112,6 @@ static void AppendActivation(const OneDNNContext& dev_ctx,
   }
 }

-static std::unordered_map<std::string, std::string> GetAttributeMap(
-    std::string act_type) {
-  std::unordered_map<std::string, std::string> attr_map;
-  if (act_type == "swish") {
-    attr_map.emplace("beta", "fuse_alpha");
-  } else if (act_type == "relu6") {
-    attr_map.emplace("threshold", "fuse_alpha");
-  } else if (act_type == "hard_sigmoid") {
-    attr_map.emplace("slope", "fuse_alpha");
-    attr_map.emplace("offset", "fuse_beta");
-  } else if (act_type == "clip") {
-    attr_map.emplace("min", "fuse_alpha");
-    attr_map.emplace("max", "fuse_beta");
-  } else {
-    attr_map.emplace("alpha", "fuse_alpha");
-    attr_map.emplace("beta", "fuse_beta");
-  }
-  return attr_map;
-}
-
-static std::vector<std::string> GetSupportedActivations() {
-  return std::vector<std::string>{"abs",
-                                  "clip",
-                                  "gelu",
-                                  "hard_sigmoid",
-                                  "hard_swish",
-                                  "leaky_relu",
-                                  "mish",
-                                  "relu",
-                                  "relu6",
-                                  "sigmoid",
-                                  "sqrt",
-                                  "swish",
-                                  "tanh"};
-}
-
 template <typename T>
 static std::vector<T> TransposeAxis(const std::vector<T>& x,
                                     const std::vector<int>& axis) {
   auto axis_set = std::set<int>(axis.begin(), axis.end());
   PADDLE_ENFORCE_EQ(axis_set.size(),
                     axis_size,
-                    paddle::platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "In an axis array, elements must be unique."));

-  PADDLE_ENFORCE_EQ(in_rank,
-                    axis_size,
-                    paddle::platform::errors::InvalidArgument(
-                        "The input dimension's size "
-                        "should be equal to the axis's size. "
-                        "But received dimension is %d, "
-                        "axis's size is %d",
-                        in_rank,
-                        axis_size));
+  PADDLE_ENFORCE_EQ(
+      in_rank,
+      axis_size,
+      phi::errors::InvalidArgument("The input dimension's size "
+                                   "should be equal to the axis's size. "
+                                   "But received dimension is %d, "
+                                   "axis's size is %d",
+                                   in_rank,
+                                   axis_size));

   PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
                     axis_size,
-                    paddle::platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "Axis values must be ranging from 0 to (dims - 1)."));

   std::vector<T> new_x(x.size());
-- 
GitLab
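
A minimal usage sketch of the shared header introduced by this patch, assuming a pass has already matched a fused op and its trailing activation; the function and variable names other than GetSupportedActivations and SetActivationAttrs are hypothetical and only illustrate how the helpers are meant to be called:

// Illustrative only; not part of the patch above.
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"

namespace paddle {
namespace framework {
namespace ir {

// Folds `act_op_desc` into `fused_op_desc` if its type is a supported
// activation. Relies entirely on the helpers from the new header.
void FoldActivationIfSupported(OpDesc* fused_op_desc, OpDesc* act_op_desc) {
  for (const auto& act_type : GetSupportedActivations()) {
    if (act_op_desc->Type() != act_type) continue;
    // Copies the mapped attributes (e.g. swish "beta" -> "fuse_alpha",
    // clip "min"/"max" -> "fuse_alpha"/"fuse_beta"), resolves gelu_tanh
    // vs. gelu_erf from "approximate", and sets use_mkldnn and
    // fuse_activation on the fused op.
    SetActivationAttrs(fused_op_desc, act_op_desc, act_type);
    return;
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle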