Unverified · Commit c7b373f2 authored by zyfncg, committed by GitHub

Clear extra attributes of activation op in OpMaker (#45772)

* clear extra attr of activation op in opmaker

* fix syntax bug

* fix mkldnn kernel

* fix merge conflict

* fix bug
Parent 01888482
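What the change amounts to: every activation OpMaker used to register the device-dispatch switches itself through `AddAttr(...).AsExtra()`; this commit deletes those registrations and declares the same defaults once, in the extra-attribute yaml at the end of the diff. A simplified before/after sketch for a hypothetical activation op `Foo` (names mirror the macro in the diff below; this is not a standalone build):

    // Before: the maker carries device-specific "extra" attributes itself.
    class FooOpMaker : public ::paddle::framework::OpProtoAndCheckerMaker {
     public:
      void Make() override {
        AddInput("X", "Input of Foo operator, an N-D Tensor.");
        AddOutput("Out", "Output of Foo operator, a Tensor with shape same as input.");
        AddAttr<bool>("use_mkldnn", "(bool, default false) Only used in mkldnn kernel")
            .SetDefault(false)
            .AsExtra();  // deleted by this commit
        AddAttr<bool>("use_cudnn", "(bool, default false) Only used in cudnn kernel")
            .SetDefault(false)
            .AsExtra();  // deleted by this commit
        AddComment("Foo Activation Operator.");
      }
    };

    // After: the maker declares only the op's core interface; use_mkldnn and
    // use_cudnn come from the op's entry in the extra-attrs yaml instead.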
@@ -213,8 +213,8 @@ class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
     std::vector<std::unique_ptr<OpDesc>> retv;
     retv.emplace_back(new OpDesc());
     try {
-      this->Apply(retv.front().get());
       retv.front()->SetRuntimeAttrMap(this->RuntimeAttrs());
+      this->Apply(retv.front().get());
     } catch (platform::EnforceNotMet& exception) {
       framework::AppendErrorOpHint(retv.front().get()->Type(), &exception);
       throw std::move(exception);
......
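The hunk above swaps two calls in `SingleGradOpMaker<OpDesc>`: the grad op's runtime attribute map is now populated before `Apply` runs, so attributes demoted from the op proto to runtime ("extra") attributes are already visible while the concrete maker builds the grad op. A minimal sketch of the resulting control flow (function shape assumed from the surrounding class; simplified):

    std::vector<std::unique_ptr<OpDesc>> operator()() const override {
      std::vector<std::unique_ptr<OpDesc>> retv;
      retv.emplace_back(new OpDesc());
      try {
        // Copy the forward op's runtime (extra) attributes first ...
        retv.front()->SetRuntimeAttrMap(this->RuntimeAttrs());
        // ... then let the concrete maker fill in the grad op, which can
        // now read or override those attributes.
        this->Apply(retv.front().get());
      } catch (platform::EnforceNotMet& exception) {
        framework::AppendErrorOpHint(retv.front().get()->Type(), &exception);
        throw std::move(exception);
      }
      return retv;
    }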
@@ -38,29 +38,20 @@ static constexpr bool CanInplaceAct() {
          GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kNoDeps;
 }
 
 #define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
   class OP_NAME##OpMaker                                                     \
       : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
    public:                                                                   \
     void Make() override {                                                   \
       AddInput("X",                                                          \
                "Input of " #OP_NAME                                          \
                " operator, an N-D Tensor, with data type float32, "         \
                "float64 or float16.");                                       \
       AddOutput("Out",                                                       \
                 "Output of " #OP_NAME                                        \
                 " operator, a Tensor with shape same as input.");            \
-      AddAttr<bool>("use_mkldnn",                                            \
-                    "(bool, default false) Only used in mkldnn kernel")      \
-          .SetDefault(false)                                                 \
-          .AsExtra();                                                        \
-      AddAttr<bool>("use_cudnn",                                             \
-                    "(bool, default false) Only used in cudnn kernel, need " \
-                    "install cudnn")                                         \
-          .SetDefault(false)                                                 \
-          .AsExtra();                                                        \
       AddComment(OP_COMMENT);                                                \
     }                                                                        \
   }
 
 template <ActBwdOpFwdDeps kDepValue, typename T>
@@ -107,8 +98,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
 //   }
 // #endif
 #ifdef PADDLE_WITH_MKLDNN
-  auto it = oper.Attrs().find("use_mkldnn");
-  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
+  if (library == framework::LibraryType::kPlain &&
       oper.CanMKLDNNBeUsed(ctx, data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
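With `use_mkldnn` no longer registered on the op proto, the kernel-type helper above stops probing `oper.Attrs()` for it and defers entirely to `CanMKLDNNBeUsed`. Any remaining call site that wants the flag directly has to treat it as optional; a sketch of such a guard (the helper name is ours, and it assumes the `HasAttr`/`Attr` accessors on Paddle's `ExecutionContext`):

    // Hypothetical guard: read the now-optional extra attribute safely.
    bool UseMKLDNN(const paddle::framework::ExecutionContext& ctx) {
      return ctx.HasAttr("use_mkldnn") && ctx.Attr<bool>("use_mkldnn");
    }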
@@ -458,10 +448,6 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
              "A LoDTensor or Tensor with the same type and size as that of x.");
     AddAttr<float>("alpha", "Slope of the activation function at x < 0.")
         .SetDefault(0.02f);
-    AddAttr<bool>("use_mkldnn",
-                  "(bool, default false) Only used in mkldnn kernel")
-        .SetDefault(false)
-        .AsExtra();
     AddComment(R"DOC(
 LeakyRelu Activation Operator.
@@ -483,35 +469,6 @@ class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<float>("beta", "The value of beta for Softplus.").SetDefault(1.0f);
     AddAttr<float>("threshold", "The value of threshold for Softplus.")
         .SetDefault(20.0f);
-    AddAttr<bool>("use_mkldnn",
-                  "(bool, default false) Only used in mkldnn kernel.")
-        .SetDefault(false)
-        .AsExtra();
-    AddAttr<bool>(
-        "use_cudnn",
-        "(bool, default false) Only used in cudnn kernel, need install cudnn.")
-        .SetDefault(false)
-        .AsExtra();
-    AddAttr<std::string>(
-        "fuse_activation_type",
-        "Fused activation type used in softplus OneDNN kernel.")
-        .SetDefault("")
-        .AsExtra();
-    AddAttr<float>(
-        "fuse_activation_alpha",
-        "Fused activation alpha parameter type used in softplus OneDNN kernel.")
-        .SetDefault(0.0f)
-        .AsExtra();
-    AddAttr<float>(
-        "fuse_activation_beta",
-        "Fused activation beta parameter type used in softplus OneDNN kernel.")
-        .SetDefault(0.0f)
-        .AsExtra();
-    AddAttr<float>(
-        "fuse_activation_scale",
-        "Fused activation scale parameter type used in softplus OneDNN kernel.")
-        .SetDefault(1.0f)
-        .AsExtra();
     AddComment(R"DOC(
 :strong:`Softplus Activation Operator`
@@ -613,10 +570,6 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``x``.");
     AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
-    AddAttr<bool>("use_mkldnn",
-                  "(bool, default false) Only used in mkldnn kernel")
-        .SetDefault(false)
-        .AsExtra();
     AddComment(R"DOC(
 ELU Activation Operator.
@@ -712,10 +665,6 @@ class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<float>("threshold",
                    "The threshold value of Relu6. Default is 6.0. ")
         .SetDefault(6.0f);
-    AddAttr<bool>("use_mkldnn",
-                  "(bool, default false) Only used in mkldnn kernel")
-        .SetDefault(false)
-        .AsExtra();
     AddComment(R"DOC(
 Relu6 Activation Operator.
@@ -817,10 +766,6 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "Input of Swish operator");
     AddOutput("Out", "Output of Swish operator");
     AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
-    AddAttr<bool>("use_mkldnn",
-                  "(bool, default false) Only used in mkldnn kernel")
-        .SetDefault(false)
-        .AsExtra();
     AddComment(R"DOC(
 Swish Activation Operator.
@@ -841,10 +786,6 @@ class MishOpMaker : public framework::OpProtoAndCheckerMaker {
              "of softplus will be used if absolute value of input is greater than "
              ":attr:`threshold`")
         .SetDefault(20.f);
-    AddAttr<bool>("use_mkldnn",
-                  "(bool, default false) Only used in mkldnn kernel")
-        .SetDefault(false)
-        .AsExtra();
     AddComment(R"DOC(
 Mish Activation Operator.
@@ -871,10 +812,6 @@ class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(6.0f);
     AddAttr<float>("offset", "The offset parameter of HardSwish operator")
         .SetDefault(3.0f);
-    AddAttr<bool>("use_mkldnn",
-                  "(bool, default false) Only used in mkldnn kernel")
-        .SetDefault(false)
-        .AsExtra();
     AddComment(R"DOC(
 HardSwish Activation Operator.
......
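All seven hand-written makers touched in this file (LeakyRelu, Softplus, ELU, Relu6, Swish, Mish, HardSwish) lose the same block of boilerplate. After the change each declares only its core attributes; for instance, LeakyReluOpMaker reduces to roughly the following (assembled from the surviving lines above; the input description is paraphrased):

    class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
     public:
      void Make() override {
        AddInput("X", "Input of LeakyRelu operator.");  // paraphrased
        AddOutput("Out",
                  "A LoDTensor or Tensor with the same type and size as that of x.");
        AddAttr<float>("alpha", "Slope of the activation function at x < 0.")
            .SetDefault(0.02f);
        AddComment(R"DOC(
    LeakyRelu Activation Operator.
    )DOC");
      }
    };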
@@ -3,6 +3,11 @@
   extra :
     attrs : [bool use_cudnn = false, bool use_mkldnn = false]
 
+- api : acosh
+  backward : acosh_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : addmm
   backward : addmm_grad
   extra :
@@ -18,12 +23,22 @@
   extra :
     attrs : [bool use_cudnn = false, bool use_mkldnn = false]
 
+- api : asinh
+  backward : asinh_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : atan2
   inputs :
     {x : X1, y : X2}
   outputs :
     out : Out
 
+- api : atanh
+  backward : atanh_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : batch_norm
   backward : batch_norm_grad
   extra :
@@ -45,6 +60,11 @@
   extra :
     attrs : [bool use_mkldnn = false]
 
+- api : ceil
+  backward : ceil_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : cholesky
   inputs :
     x : X
@@ -107,6 +127,16 @@
   extra :
     attrs : [bool use_cudnn = true, bool use_mkldnn = false, int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()]
 
+- api : cos
+  backward : cos_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : cosh
+  backward : cosh_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : cross
   inputs :
     {x : X, y : Y}
@@ -179,6 +209,11 @@
   extra :
     attrs : [bool fix_seed = false, int seed = 0]
 
+- api : elu
+  backward : elu_grad
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - api : erf
   inputs :
     x : X
@@ -191,6 +226,16 @@
   outputs :
     out : Out
 
+- api : exp
+  backward : exp_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : expm1
+  backward : expm1_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : fft_c2c
   inputs: {x: X}
   outputs: {out: Out}
@@ -203,6 +248,11 @@
   inputs: {x: X}
   outputs: {out: Out}
 
+- api : floor
+  backward : floor_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : frobenius_norm
   backward : frobenius_norm_grad
   extra :
@@ -223,6 +273,11 @@
   extra :
     attrs : [bool is_test = false]
 
+- api : hard_swish
+  backward : hard_swish_grad
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - api : inplace_abn
   backward : inplace_abn_grad
   extra :
@@ -233,6 +288,11 @@
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
 
+- api : leaky_relu
+  backward : leaky_relu_grad
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - api : lgamma
   inputs :
     x : X
@@ -244,11 +304,36 @@
   extra :
     attrs : [bool use_mkldnn = false]
 
+- api : log
+  backward : log_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : log10
+  backward : log10_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : log1p
+  backward : log1p_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : log2
+  backward : log2_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : log_softmax
   backward : log_softmax_grad
   extra :
     attrs : [bool use_mkldnn = false]
 
+- api : logsigmoid
+  backward : logsigmoid_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : lrn
   backward : lrn_grad
   extra :
@@ -261,6 +346,11 @@
            str mkldnn_data_type = "float32", 'int[] fused_reshape_X = {}', 'int[] fused_reshape_Y = {}',
            'int[] fused_transpose_X = {}', 'int[] fused_transpose_Y = {}',]
 
+- api : mish
+  backward : mish_grad
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - api : mv
   inputs :
     {x : X, vec : Vec}
@@ -293,6 +383,21 @@
   outputs :
     out : Out
 
+- api : prelu
+  backward : prelu_grad
+  extra :
+    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
+
+- api : prelu
+  backward : prelu_grad
+  extra :
+    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
+
+- api : reciprocal
+  backward : reciprocal_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : reduce_all
   extra :
     attrs : [bool use_mkldnn = false]
@@ -336,15 +441,30 @@
   extra :
     attrs : [bool use_mkldnn = false]
 
+- api : relu
+  backward : relu_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : relu6
+  backward : relu6_grad
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - api : renorm
   backward : renorm_grad
   extra :
     attrs : [bool use_mkldnn = false, bool use_cudnn = false]
 
-- api : rnn
-  backward : rnn_grad
+- api : round
+  backward : round_grad
   extra :
-    attrs : [bool is_test = false]
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : rsqrt
+  backward : rsqrt_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
 
 - api : seed
   extra :
@@ -359,6 +479,26 @@
   extra :
     attrs : [bool use_mkldnn = false]
 
+- api : sigmoid
+  backward : sigmoid_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : silu
+  backward : silu_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : sin
+  backward : sin_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : sinh
+  backward : sinh_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : slice
   backward : slice_grad
   extra :
@@ -368,10 +508,21 @@
   backward : softmax_grad
   extra :
     attrs : [bool use_cudnn = false, bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
 
-- api : prelu
-  backward : prelu_grad
+- api : softplus
+  backward : softplus_grad
   extra :
-    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false, str fuse_activation_type = "", float fuse_activation_alpha = 0.0f,
+             float fuse_activation_beta = 0.0f, float fuse_activation_scale = 1.0f]
+
+- api : softsign
+  backward : softsign_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : rnn
+  backward : rnn_grad
+  extra :
+    attrs : [bool is_test = false]
 
 - api : solve
   inputs :
@@ -379,6 +530,16 @@
   outputs :
     out : Out
 
+- api : sqrt
+  backward : sqrt_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : square
+  backward : square_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : squeeze (squeeze2)
   backward : squeeze_grad (squeeze2_grad)
   extra :
@@ -389,11 +550,31 @@
   extra :
     attrs : [bool use_mkldnn = false]
 
+- api : swish
+  backward : swish_grad
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - api : sync_batch_norm
   backward : sync_batch_norm_grad
   extra :
     attrs : [bool use_mkldnn = false, bool fuse_with_relu = false]
 
+- api : tan
+  backward : tan_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : tanh
+  backward : tanh_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
+- api : tanh_shrink
+  backward : tanh_shrink_grad
+  extra :
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
+
 - api : trace
   inputs :
     x : Input
......
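Every entry added above follows the same schema: the op's public api name, its backward op, and the list of extra attributes (with defaults) that the runtime re-attaches instead of the OpMaker registering them. To make the schema explicit, here is a hypothetical in-memory shape for one parsed entry (types and names are ours, not Paddle's):

    #include <string>
    #include <vector>

    // One "type name = default" triple from an attrs : [...] list.
    struct ExtraAttrDef {
      std::string type;           // e.g. "bool"
      std::string name;           // e.g. "use_mkldnn"
      std::string default_value;  // e.g. "false"
    };

    // One "- api : ..." block from the yaml.
    struct OpCompatEntry {
      std::string api;                  // e.g. "softplus"
      std::string backward;             // e.g. "softplus_grad"
      std::vector<ExtraAttrDef> extra;  // e.g. {{"bool", "use_mkldnn", "false"}}
    };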