diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index 1e1d3cf7f7634e2e5a433025f175202bd6c4b40e..7ae4d2f6b6c0b0f30c06adc34c811bfe34b59fa6 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -206,120 +206,57 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker { } // namespace paddle namespace ops = paddle::operators; + REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL(sigmoid, - ops::ActivationKernel>); -REGISTER_OP_CPU_KERNEL( - sigmoid_grad, ops::ActivationGradKernel>); REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL( - exp, - ops::ActivationKernel); -REGISTER_OP_CPU_KERNEL(exp_grad, - ops::ActivationGradKernel); REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL(relu, - ops::ActivationKernel>); -REGISTER_OP_CPU_KERNEL( - relu_grad, ops::ActivationGradKernel>); REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL( - tanh, - ops::ActivationKernel); -REGISTER_OP_CPU_KERNEL( - tanh_grad, ops::ActivationGradKernel>); REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL( - sqrt, - ops::ActivationKernel); -REGISTER_OP_CPU_KERNEL( - sqrt_grad, ops::ActivationGradKernel>); REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL( - abs, - ops::ActivationKernel); -REGISTER_OP_CPU_KERNEL(abs_grad, - ops::ActivationGradKernel); REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker, reciprocal_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL(reciprocal, - ops::ActivationKernel>); -REGISTER_OP_CPU_KERNEL( - reciprocal_grad, - ops::ActivationGradKernel>); REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL( - log, - ops::ActivationKernel); -REGISTER_OP_CPU_KERNEL( - log_grad, ops::ActivationGradKernel>); REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL(square, - ops::ActivationKernel); -REGISTER_OP_CPU_KERNEL( - square_grad, ops::ActivationGradKernel>); REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL(softsign, - ops::ActivationKernel>); -REGISTER_OP_CPU_KERNEL( - softsign_grad, ops::ActivationGradKernel>); REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker, brelu_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL(brelu, - ops::BReluKernel); -REGISTER_OP_CPU_KERNEL(brelu_grad, - ops::BReluGradKernel); REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, soft_relu_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL(soft_relu, - ops::SoftReluKernel); -REGISTER_OP_CPU_KERNEL( - soft_relu_grad, ops::SoftReluGradKernel); REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL(pow, ops::PowKernel); -REGISTER_OP_CPU_KERNEL(pow_grad, - ops::PowGradKernel); REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad, ops::ActivationOpGrad); -REGISTER_OP_CPU_KERNEL(stanh, - ops::STanhKernel); -REGISTER_OP_CPU_KERNEL(stanh_grad, - ops::STanhGradKernel); + +#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \ + REGISTER_OP_CPU_KERNEL( \ + act_type, \ + paddle::operators::ActivationKernel>); \ + REGISTER_OP_CPU_KERNEL(act_type##_grad, \ + paddle::operators::ActivationGradKernel< \ + paddle::platform::CPUPlace, \ + paddle::operators::grad_functor>); + +FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu index 56886d8b1b93a19e9a01798ef79e89f9b5d6fca1..93e9f1c694bacba48c4f8c46f90fb5b512bead99 100644 --- a/paddle/operators/activation_op.cu +++ b/paddle/operators/activation_op.cu @@ -15,93 +15,14 @@ #define EIGEN_USE_GPU #include "paddle/operators/activation_op.h" -namespace ops = paddle::operators; - -REGISTER_OP_GPU_KERNEL(sigmoid, - ops::ActivationKernel>); -REGISTER_OP_GPU_KERNEL( - sigmoid_grad, ops::ActivationGradKernel>); - -REGISTER_OP_GPU_KERNEL( - exp, - ops::ActivationKernel); -REGISTER_OP_GPU_KERNEL(exp_grad, - ops::ActivationGradKernel); -REGISTER_OP_GPU_KERNEL(relu, - ops::ActivationKernel>); -REGISTER_OP_GPU_KERNEL( - relu_grad, ops::ActivationGradKernel>); - -REGISTER_OP_GPU_KERNEL( - tanh, - ops::ActivationKernel); -REGISTER_OP_GPU_KERNEL( - tanh_grad, ops::ActivationGradKernel>); - -REGISTER_OP_GPU_KERNEL( - sqrt, - ops::ActivationKernel); -REGISTER_OP_GPU_KERNEL( - sqrt_grad, ops::ActivationGradKernel>); - -REGISTER_OP_GPU_KERNEL( - abs, - ops::ActivationKernel); -REGISTER_OP_GPU_KERNEL(abs_grad, - ops::ActivationGradKernel); - -REGISTER_OP_GPU_KERNEL(reciprocal, - ops::ActivationKernel>); -REGISTER_OP_GPU_KERNEL( - reciprocal_grad, - ops::ActivationGradKernel>); - -REGISTER_OP_GPU_KERNEL( - log, - ops::ActivationKernel); -REGISTER_OP_GPU_KERNEL( - log_grad, ops::ActivationGradKernel>); - -REGISTER_OP_GPU_KERNEL(square, - ops::ActivationKernel); -REGISTER_OP_GPU_KERNEL( - square_grad, ops::ActivationGradKernel>); - -REGISTER_OP_GPU_KERNEL(softsign, - ops::ActivationKernel>); -REGISTER_OP_GPU_KERNEL( - softsign_grad, ops::ActivationGradKernel>); - -REGISTER_OP_GPU_KERNEL(brelu, - ops::BReluKernel); -REGISTER_OP_GPU_KERNEL(brelu_grad, - ops::BReluGradKernel); - -REGISTER_OP_GPU_KERNEL(soft_relu, - ops::SoftReluKernel); -REGISTER_OP_GPU_KERNEL( - soft_relu_grad, ops::SoftReluGradKernel); - -REGISTER_OP_GPU_KERNEL(pow, ops::PowKernel); -REGISTER_OP_GPU_KERNEL(pow_grad, - ops::PowGradKernel); - -REGISTER_OP_GPU_KERNEL(stanh, - ops::STanhKernel); -REGISTER_OP_GPU_KERNEL(stanh_grad, - ops::STanhGradKernel); +#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \ + REGISTER_OP_GPU_KERNEL( \ + act_type, \ + paddle::operators::ActivationKernel>); \ + REGISTER_OP_GPU_KERNEL(act_type##_grad, \ + paddle::operators::ActivationGradKernel< \ + paddle::platform::GPUPlace, \ + paddle::operators::grad_functor>); + +FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL); diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index b9f52e1af3958b247e4854389cb467e2fce25e27..ff35c2d97e856ab76581c74512a0b451ea6fe60c 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -19,9 +19,12 @@ namespace paddle { namespace operators { -template -class ActivationKernel : public framework::OpKernel { +template +class ActivationKernel + : public framework::OpKernel { public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); auto* Y = context.Output("Y"); @@ -31,13 +34,20 @@ class ActivationKernel : public framework::OpKernel { auto y = framework::EigenVector::Flatten(*Y); auto place = context.GetEigenDevice(); Functor functor; + + auto attrs = functor.GetAttrs(); + for (auto& attr : attrs) { + *attr.second = context.Attr(attr.first); + } functor(place, x, y); } }; -template -class ActivationGradKernel : public framework::OpKernel { +template +class ActivationGradKernel + : public framework::OpKernel { public: + using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); auto* Y = context.Input("Y"); @@ -51,159 +61,210 @@ class ActivationGradKernel : public framework::OpKernel { auto dx = framework::EigenVector::Flatten(*dX); auto place = context.GetEigenDevice(); Functor functor; + auto attrs = functor.GetAttrs(); + for (auto& attr : attrs) { + *attr.second = context.Attr(attr.first); + } functor(place, x, y, dy, dx); } }; +template +struct BaseActivationFunctor { + using ELEMENT_TYPE = T; + + using AttrPair = std::vector>; + + AttrPair GetAttrs() { return AttrPair(); } +}; + // sigmoid(x) = 1 / (1 + exp(-x)) template -struct SigmoidFunctor { +struct SigmoidFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y) { + void operator()(Device d, X x, Y y) const { y.device(d) = static_cast(1) / (static_cast(1) + (-x).exp()); } }; template -struct SigmoidGradFunctor { +struct SigmoidGradFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * y * (static_cast(1) - y); } }; // exp(x) = e^x -struct ExpFunctor { +template +struct ExpFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y) { + void operator()(Device d, X x, Y y) const { y.device(d) = x.exp(); } }; -struct ExpGradFunctor { +template +struct ExpGradFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * y; } }; // relu(x) = max(x, 0) template -struct ReluFunctor { +struct ReluFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y) { + void operator()(Device d, X x, Y y) const { y.device(d) = x.cwiseMax(static_cast(0)); } }; template -struct ReluGradFunctor { +struct ReluGradFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * (x > static_cast(0)).template cast(); } }; // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) -struct TanhFunctor { +template +struct TanhFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y) { + void operator()(Device d, X x, Y y) const { y.device(d) = x.tanh(); } }; template -struct TanhGradFunctor { +struct TanhGradFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * (static_cast(1) - y * y); } }; // sqrt(x) = x^(1/2) -struct SqrtFunctor { +template +struct SqrtFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y) { + void operator()(Device d, X x, Y y) const { y.device(d) = x.sqrt(); } }; template -struct SqrtGradFunctor { +struct SqrtGradFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + void operator()(Device d, X x, Y y, dY dy, dX dx) const { const Y y_conj = Eigen::numext::conj(y); dx.device(d) = static_cast(0.5) * dy / y_conj; } }; // abs(x) = |x| -struct AbsFunctor { +template +struct AbsFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y) { + void operator()(Device d, X x, Y y) const { y.device(d) = x.abs(); } }; -struct AbsGradFunctor { +template +struct AbsGradFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * x.sign(); } }; // reciprocal(x) = 1 / x template -struct ReciprocalFunctor { +struct ReciprocalFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y) { + void operator()(Device d, X x, Y y) const { y.device(d) = static_cast(1) / x; } }; template -struct ReciprocalGradFunctor { +struct ReciprocalGradFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * static_cast(-1) * y * y; } }; // log(x) = natural logarithm of x -struct LogFunctor { +template +struct LogFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y) { + void operator()(Device d, X x, Y y) const { y.device(d) = x.log(); } }; template -struct LogGradFunctor { +struct LogGradFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * (static_cast(1) / x); } }; // square(x) = x^2 -struct SquareFunctor { +template +struct SquareFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y) { + void operator()(Device d, X x, Y y) const { y.device(d) = x.square(); } }; template -struct SquareGradFunctor { +struct SquareGradFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * static_cast(2) * x; } }; +template +struct BReluFunctor : public BaseActivationFunctor { + float t_min; + float t_max; + + // NOTE: Explicit hides the `BaseActivationFunctor::GetAttrs` + // not polymorphism for speed. + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"t_min", &t_min}, {"t_max", &t_max}}; + } + + template + void operator()(Device d, X x, Y y) const { + y.device(d) = x.cwiseMax(t_min).cwiseMin(t_max); + } +}; + +template +struct BReluGradFunctor : public BaseActivationFunctor { + float t_min; + float t_max; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"t_min", &t_min}, {"t_max", &t_max}}; + } + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + dx.device(d) = dy * ((x > t_min) * (x < t_max)).template cast(); + } +}; + // softsign(x) = x / (1 + |x|) template -struct SoftsignFunctor { +struct SoftsignFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) { y.device(d) = x / (static_cast(1) + x.abs()); @@ -213,7 +274,7 @@ struct SoftsignFunctor { // d(softsign(x))/dx = 1 / (1 + |x|)^2 // Taken from https://en.wikipedia.org/wiki/Activation_function template -struct SoftsignGradFunctor { +struct SoftsignGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { dx.device(d) = @@ -221,153 +282,101 @@ struct SoftsignGradFunctor { } }; -template -class BReluKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Y = context.Output("Y"); - auto t_min = static_cast(context.Attr("t_min")); - auto t_max = static_cast(context.Attr("t_max")); - Y->mutable_data(context.GetPlace()); - - auto x = framework::EigenVector::Flatten(*X); - auto y = framework::EigenVector::Flatten(*Y); - auto place = context.GetEigenDevice(); - y.device(place) = x.cwiseMax(t_min).cwiseMin(t_max); +template +struct SoftReluFunctor : public BaseActivationFunctor { + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; } -}; -template -class BReluGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* dY = context.Input(framework::GradVarName("Y")); - auto* dX = context.Output(framework::GradVarName("X")); - auto t_min = static_cast(context.Attr("t_min")); - auto t_max = static_cast(context.Attr("t_max")); - dX->mutable_data(context.GetPlace()); - - auto dy = framework::EigenVector::Flatten(*dY); - auto x = framework::EigenVector::Flatten(*X); - auto dx = framework::EigenVector::Flatten(*dX); - auto place = context.GetEigenDevice(); - - dx.device(place) = dy * ((x > t_min) * (x < t_max)).template cast(); + template + void operator()(Device d, X x, Y y) const { + auto temp = x.cwiseMax(-threshold).cwiseMin(threshold); + y.device(d) = (static_cast(1) + temp.exp()).log(); } }; -template -class SoftReluKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Y = context.Output("Y"); - auto threshold = static_cast(context.Attr("threshold")); - Y->mutable_data(context.GetPlace()); - - auto x = framework::EigenVector::Flatten(*X); - auto y = framework::EigenVector::Flatten(*Y); - auto place = context.GetEigenDevice(); - auto temp = x.cwiseMax(-threshold).cwiseMin(threshold).eval(); - y.device(place) = (static_cast(1) + temp.exp()).log(); +template +struct SoftReluGradFunctor : public BaseActivationFunctor { + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; } -}; - -template -class SoftReluGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Y = context.Input("Y"); - auto* dY = context.Input(framework::GradVarName("Y")); - auto* dX = context.Output(framework::GradVarName("X")); - auto threshold = static_cast(context.Attr("threshold")); - dX->mutable_data(context.GetPlace()); - - auto x = framework::EigenVector::Flatten(*X); - auto y = framework::EigenVector::Flatten(*Y); - auto dy = framework::EigenVector::Flatten(*dY); - auto dx = framework::EigenVector::Flatten(*dX); - auto place = context.GetEigenDevice(); + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { auto temp = ((x > -threshold) * (x < threshold)).template cast().eval(); - dx.device(place) = dy * (static_cast(1) - (-y).exp()) * temp; + dx.device(d) = dy * (static_cast(1) - (-y).exp()) * temp; } }; -template -class PowKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Y = context.Output("Y"); - auto factor = static_cast(context.Attr("factor")); - Y->mutable_data(context.GetPlace()); - - auto x = framework::EigenVector::Flatten(*X); - auto y = framework::EigenVector::Flatten(*Y); - auto place = context.GetEigenDevice(); - y.device(place) = x.pow(factor); +template +struct PowFunctor : public BaseActivationFunctor { + float factor; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"factor", &factor}}; + } + template + void operator()(Device d, X x, Y y) const { + y.device(d) = x.pow(factor); } }; -template -class PowGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* dY = context.Input(framework::GradVarName("Y")); - auto* dX = context.Output(framework::GradVarName("X")); - auto factor = static_cast(context.Attr("factor")); - dX->mutable_data(context.GetPlace()); - - auto dy = framework::EigenVector::Flatten(*dY); - auto x = framework::EigenVector::Flatten(*X); - auto dx = framework::EigenVector::Flatten(*dX); - auto place = context.GetEigenDevice(); - - dx.device(place) = dy * factor * x.pow(factor - static_cast(1)); +template +struct PowGradFunctor : public BaseActivationFunctor { + float factor; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"factor", &factor}}; + } + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + dx.device(d) = dy * factor * x.pow(factor - static_cast(1)); } }; -template -class STanhKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Y = context.Output("Y"); - auto scale_a = static_cast(context.Attr("scale_a")); - auto scale_b = static_cast(context.Attr("scale_b")); - Y->mutable_data(context.GetPlace()); +template +struct STanhFunctor : public BaseActivationFunctor { + float scale_a; + float scale_b; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; + } - auto x = framework::EigenVector::Flatten(*X); - auto y = framework::EigenVector::Flatten(*Y); - auto place = context.GetEigenDevice(); - y.device(place) = scale_b * (scale_a * x).tanh(); + template + void operator()(Device d, X x, Y y) const { + y.device(d) = scale_b * (scale_a * x).tanh(); } }; -template -class STanhGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* dY = context.Input(framework::GradVarName("Y")); - auto* dX = context.Output(framework::GradVarName("X")); - auto scale_a = static_cast(context.Attr("scale_a")); - auto scale_b = static_cast(context.Attr("scale_b")); - dX->mutable_data(context.GetPlace()); - - auto dy = framework::EigenVector::Flatten(*dY); - auto x = framework::EigenVector::Flatten(*X); - auto dx = framework::EigenVector::Flatten(*dX); - auto place = context.GetEigenDevice(); +template +struct STanhGradFunctor : public BaseActivationFunctor { + float scale_a; + float scale_b; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; + } + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { auto temp = (scale_a * x).tanh() * (scale_a * x).tanh(); - dx.device(place) = dy * scale_a * scale_b * (static_cast(1) - temp); + dx.device(d) = dy * scale_a * scale_b * (static_cast(1) - temp); } }; } // namespace operators } // namespace paddle + +#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ + __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ + __macro(exp, ExpFunctor, ExpGradFunctor); \ + __macro(relu, ReluFunctor, ReluGradFunctor); \ + __macro(tanh, TanhFunctor, TanhGradFunctor); \ + __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ + __macro(abs, AbsFunctor, AbsGradFunctor); \ + __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ + __macro(log, LogFunctor, LogGradFunctor); \ + __macro(square, SquareFunctor, SquareGradFunctor); \ + __macro(brelu, BReluFunctor, BReluGradFunctor); \ + __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ + __macro(pow, PowFunctor, PowGradFunctor); \ + __macro(stanh, STanhFunctor, STanhGradFunctor); \ + __macro(softsign, SoftsignFunctor, SoftsignGradFunctor)