diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 6be872b028ca0ba6a11b803da7222097e8dcd741..6905f3d79546e5dcadf2aeec64f85e9a72937e16 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -1659,15 +1650,6 @@ REGISTER_OPERATOR(
     ops::ActivationOpDoubleGrad<ops::CELUGradFunctor<float>::FwdDeps()>,
     ops::ActivationDoubleGradOpInplaceInferer);
 
-REGISTER_ACTIVATION_CPU_KERNEL(celu, CELU, CELUFunctor, CELUGradFunctor);
-REGISTER_OP_CPU_KERNEL(
-    celu_grad_grad, ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
-                                              ops::CELUGradGradFunctor<float>>,
-    ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
-                              ops::CELUGradGradFunctor<double>>,
-    ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
-                              ops::CELUGradGradFunctor<plat::float16>>);
-
 /* ========================================================================== */
 
 /* ===========================   sqrt register  ============================= */
@@ -1687,13 +1678,6 @@ REGISTER_OPERATOR(
     ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
     ops::ActivationDoubleGradOpInplaceInferer);
 
-REGISTER_OP_CPU_KERNEL(
-    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
-                                              ops::SqrtGradGradFunctor<float>>,
-    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
-                              ops::SqrtGradGradFunctor<double>>,
-    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
-                              ops::SqrtGradGradFunctor<plat::float16>>);
 /* ========================================================================== */
 
 /* ===========================   rsqrt register  =============================
@@ -1714,14 +1698,6 @@ REGISTER_OPERATOR(
     ops::ActivationOpDoubleGrad<ops::RsqrtGradGradFunctor<float>::FwdDeps()>,
     ops::ActivationDoubleGradOpInplaceInferer);
 
-REGISTER_OP_CPU_KERNEL(
-    rsqrt_grad_grad,
-    ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
-                               ops::RsqrtGradGradFunctor<float>>,
-    ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
-                               ops::RsqrtGradGradFunctor<double>>,
-    ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
-                               ops::RsqrtGradGradFunctor<plat::float16>>);
 /* ========================================================================== */
 
 /* ==========================   square register  ============================ */
@@ -1742,18 +1718,6 @@ REGISTER_OPERATOR(
     ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
     ops::ActivationDoubleGradOpInplaceInferer);
 
-REGISTER_OP_CPU_KERNEL(
-    square_grad_grad,
-    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
-                                ops::SquareGradGradFunctor<float>>,
-    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
-                                ops::SquareGradGradFunctor<double>>,
-    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
-                                ops::SquareGradGradFunctor<plat::float16>>,
-    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
-                                ops::SquareGradGradFunctor<int>>,
-    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
-                                ops::SquareGradGradFunctor<int64_t>>);
 /* ========================================================================== */
 
 /* ==========================   pow register  ============================ */
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index 8214b733f86da5240545c5effdc132b0d50c3144..5f3916a65e79274e717ed51a91224c5248f171f8 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -296,9 +296,14 @@ USE_PHI_FUNCTOR(Mish)
 USE_PHI_FUNCTOR(STanh)
 USE_PHI_FUNCTOR(Reciprocal)
 USE_PHI_FUNCTOR(Square)
+USE_PHI_DOUBLE_GRAD_FUNCTOR(Square)
 USE_PHI_FUNCTOR(Sqrt)
+USE_PHI_DOUBLE_GRAD_FUNCTOR(Sqrt)
 USE_PHI_FUNCTOR(Rsqrt)
+USE_PHI_DOUBLE_GRAD_FUNCTOR(Rsqrt)
 USE_PHI_FUNCTOR(Softplus)
+USE_PHI_FUNCTOR(CELU)
+USE_PHI_DOUBLE_GRAD_FUNCTOR(CELU)
 
 template <typename T>
 using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
@@ -331,68 +336,6 @@ using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;
 template <typename T>
 using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;
 
-template <typename T>
-struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device>
-  void operator()(const Device& dev, const framework::Tensor* Out,
-                  const framework::Tensor* ddX, framework::Tensor* ddOut,
-                  framework::Tensor* dOut, const framework::Tensor* dX) const {
-    auto* d = dev.eigen_device();
-    auto ddx = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
-    auto out = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
-    // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
-    // calculate dy first, so ddy can inplace ddx
-    if (dOut) {
-      auto dx = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
-      auto dout = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
-      dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
-    }
-    if (ddOut) {
-      auto ddout = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
-      ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
-    }
-  }
-  static constexpr ActBwdOpFwdDeps FwdDeps() {
-    return ActBwdOpFwdDeps::kDepOut;
-  }
-};
-
-template <typename T>
-struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device>
-  void operator()(const Device& dev, const framework::Tensor* Out,
-                  const framework::Tensor* ddX, framework::Tensor* ddOut,
-                  framework::Tensor* dOut, const framework::Tensor* dX) const {
-    auto* d = dev.eigen_device();
-    auto ddx = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
-    auto out = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
-
-    // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
-    if (dOut) {
-      auto dx = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
-      auto dout = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
-      dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
-    }
-    if (ddOut) {
-      auto ddout = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
-      ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
-    }
-  }
-  static constexpr ActBwdOpFwdDeps FwdDeps() {
-    return ActBwdOpFwdDeps::kDepOut;
-  }
-};
-
 // relu6(x) = min(max(0, x), 6)
 template <typename T>
 struct Relu6Functor : public BaseActivationFunctor<T> {
@@ -498,51 +441,6 @@ class ELUGradKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename T>
-struct CELUFunctor : public BaseActivationFunctor<T> {
-  float alpha;
-  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
-    return {{"alpha", &alpha}};
-  }
-
-  template <typename Device, typename X, typename Out>
-  void operator()(Device d, X x, Out out) const {
-    out.device(d) =
-        (x < static_cast<T>(0))
-            .select(static_cast<T>(alpha) *
-                        ((x / static_cast<T>(alpha)).exp() - static_cast<T>(1)),
-                    x);
-  }
-};
-
-template <typename T>
-struct CELUGradFunctor : public BaseActivationFunctor<T> {
-  float alpha;
-  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
-    return {{"alpha", &alpha}};
-  }
-  template <typename Device, typename X, typename Out, typename dOut,
-            typename dX>
-  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    auto temp_a_pos = static_cast<T>(alpha > 0);
-    auto temp_a_neg = static_cast<T>(alpha <= 0);
-    auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
-    auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
-
-    // dx = dout, if alpha > 0 and x > 0
-    // dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
-    // dx = dout , if alpha < 0 and x > 0
-    // dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
-    dx.device(d) =
-        dout * temp_a_pos * temp_x_pos +
-        dout * (x / static_cast<T>(alpha)).exp() * temp_a_pos * temp_x_neg +
-        dout * temp_a_neg * temp_x_pos +
-        dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg;
-  }
-
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
-};
-
 template <typename T>
 struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device>
   void operator()(const Device& dev, const framework::Tensor* X,
                   const framework::Tensor* ddX, framework::Tensor* ddOut,
                   const framework::Tensor* dOut, framework::Tensor* dX) const {
@@ -564,74 +462,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
-template <typename T>
-struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
-  float alpha;
-  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
-    return {{"alpha", &alpha}};
-  }
-  template <typename Device>
-  void operator()(const Device& dev, const framework::Tensor* X,
-                  const framework::Tensor* ddX, framework::Tensor* ddOut,
-                  const framework::Tensor* dOut, framework::Tensor* dX) const {
-    auto* d = dev.eigen_device();
-    auto ddx = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad"));
-    auto x = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad"));
-
-    if (dX) {
-      auto dx = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad"));
-      auto dout = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad"));
-      dx.device(*d) = ddx * dout / static_cast<T>(alpha) *
-                      (x / static_cast<T>(alpha)).exp() *
-                      (x <= static_cast<T>(0)).template cast<T>();
-    }
-
-    if (ddOut) {
-      auto ddout = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad"));
-      ddout.device(*d) = ddx *
-                         ((x > static_cast<T>(0)).template cast<T>() +
-                          (x / static_cast<T>(alpha)).exp() *
-                              (x <= static_cast<T>(0)).template cast<T>())
-                             .template cast<T>();
-    }
-  }
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
-};
-
-template <typename T>
-struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device>
-  void operator()(const Device& dev, const framework::Tensor* X,
-                  const framework::Tensor* ddX, framework::Tensor* ddOut,
-                  const framework::Tensor* dOut, framework::Tensor* dX) const {
-    auto* d = dev.eigen_device();
-    auto ddx = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
-    auto x = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
-    // square GradGrad: ddy=2x*ddx, dx=2dy*ddx
-    // calculate dx first, so ddy can inplace ddx
-    if (dX) {
-      auto dx = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
-      auto dout = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
-      dx.device(*d) = ddx * static_cast<T>(2) * dout;
-    }
-    if (ddOut) {
-      auto ddout = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
-      ddout.device(*d) = ddx * static_cast<T>(2) * x;
-    }
-  }
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
-};
-
 // TODO(dengkaipeng): double gradient calculation for Square/Sqrt need
 // DOut(dy) as input(not output), tensor extraction is different from
 // others. Implement extraction kernel separately here.
@@ -675,29 +505,6 @@ inline void ExtractDoubleGradTensorWithInputDOut(
   }
 }
 
-template <typename DeviceContext, typename Functor>
-class SquareDoubleGradKernel
-    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
- public:
-  using T = typename Functor::ELEMENT_TYPE;
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const framework::Tensor *X, *ddX, *dOut;
-    X = ddX = dOut = nullptr;
-    framework::Tensor *dX, *ddOut;
-    dX = ddOut = nullptr;
-
-    ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
-
-    if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
-    if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
-
-    auto& place = ctx.template device_context<DeviceContext>();
-
-    Functor functor;
-    functor(place, X, ddX, ddOut, dOut, dX);
-  }
-};
-
 template <typename T>
 struct SoftsignFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
@@ -721,153 +528,6 @@ struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
-template <typename DeviceContext, typename Functor>
-class CELUDoubleGradKernel
-    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
- public:
-  using T = typename Functor::ELEMENT_TYPE;
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const framework::Tensor *X, *ddX, *dOut;
-    X = ddX = dOut = nullptr;
-    framework::Tensor *dX, *ddOut;
-    dX = ddOut = nullptr;
-
-    ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
-
-    if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
-    if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
-
-    auto& place = ctx.template device_context<DeviceContext>();
-
-    Functor functor;
-    auto attrs = functor.GetAttrs();
-    for (auto& attr : attrs) {
-      *attr.second = ctx.Attr<float>(attr.first);
-    }
-    functor(place, X, ddX, ddOut, dOut, dX);
-  }
-};
-
-template <typename DeviceContext, typename Functor>
-class SqrtDoubleGradKernel
-    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
- public:
-  using T = typename Functor::ELEMENT_TYPE;
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const framework::Tensor *Out, *dX, *ddX;
-    Out = dX = ddX = nullptr;
-    framework::Tensor *ddOut, *dOut;
-    ddOut = dOut = nullptr;
-
-    // extract ddx(input), ddout(output)
-    auto ddx_var = ctx.InputVar("DDX");
-    auto ddo_var = ctx.OutputVar("DDOut");
-    PADDLE_ENFORCE_NOT_NULL(
-        ddx_var, platform::errors::NotFound(
-                     "Cannot get input Variable DDX, variable name = %s",
-                     ctx.InputName("DDX")));
-    ddX = ctx.Input<framework::Tensor>("DDX");
-    if (ddo_var) {
-      ddOut = ctx.Output<framework::Tensor>("DDOut");
-    }
-    PADDLE_ENFORCE_NOT_NULL(
-        ddX, platform::errors::NotFound(
-                 "Cannot get input Variable DDX, variable name = %s",
-                 ctx.InputName("DDX")));
-
-    // extract out(input), dout(output)
-    auto out_var = ctx.InputVar("Out");
-    PADDLE_ENFORCE_NOT_NULL(
-        out_var, platform::errors::NotFound(
-                     "Cannot get input Variable Out, variable name = %s",
-                     ctx.InputName("Out")));
-    auto dout_var = ctx.OutputVar("DOut");
-    Out = ctx.Input<framework::Tensor>("Out");
-    if (dout_var) {
-      dOut = ctx.Output<framework::Tensor>("DOut");
-    }
-
-    // extract dx(input)
-    auto dx_var = ctx.InputVar("DX");
-    PADDLE_ENFORCE_NOT_NULL(
-        dx_var, platform::errors::NotFound(
-                    "Cannot get input Variable DX, variable name = %s",
-                    ctx.InputName("DX")));
-    if (dx_var) {
-      dX = ctx.Input<framework::Tensor>("DX");
-    }
-
-    if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
-    if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
-
-    auto& place = ctx.template device_context<DeviceContext>();
-
-    Functor functor;
-    functor(place, Out, ddX, ddOut, dOut, dX);
-  }
-};
-
-// rsqrt Grad: dx = -0.5 * dy * y * y * y
-// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx
-template <typename DeviceContext, typename Functor>
-class RsqrtDoubleGradKernel
-    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
- public:
-  using T = typename Functor::ELEMENT_TYPE;
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const framework::Tensor *Out, *dX, *ddX;
-    Out = dX = ddX = nullptr;
-    framework::Tensor *ddOut, *dOut;
-    ddOut = dOut = nullptr;
-
-    // extract ddx(input), ddout(output)
-    auto ddx_var = ctx.InputVar("DDX");
-    auto ddo_var = ctx.OutputVar("DDOut");
-    PADDLE_ENFORCE_NOT_NULL(
-        ddx_var, platform::errors::NotFound(
-                     "Cannot get input Variable DDX, variable name = %s",
-                     ctx.InputName("DDX")));
-    ddX = ctx.Input<framework::Tensor>("DDX");
-    if (ddo_var) {
-      ddOut = ctx.Output<framework::Tensor>("DDOut");
-    }
-    PADDLE_ENFORCE_NOT_NULL(
-        ddX, platform::errors::NotFound(
-                 "Cannot get input Variable DDX, variable name = %s",
-                 ctx.InputName("DDX")));
-
-    // extract out(input), dout(output)
-    auto out_var = ctx.InputVar("Out");
-    PADDLE_ENFORCE_NOT_NULL(
-        out_var, platform::errors::NotFound(
-                     "Cannot get input Variable Out, variable name = %s",
-                     ctx.InputName("Out")));
-    auto dout_var = ctx.OutputVar("DOut");
-    Out = ctx.Input<framework::Tensor>("Out");
-    if (dout_var) {
-      dOut = ctx.Output<framework::Tensor>("DOut");
-    }
-
-    // extract dx(input)
-    auto dx_var = ctx.InputVar("DX");
-    PADDLE_ENFORCE_NOT_NULL(
-        dx_var, platform::errors::NotFound(
-                    "Cannot get input Variable DX, variable name = %s",
-                    ctx.InputName("DX")));
-    if (dx_var) {
-      dX = ctx.Input<framework::Tensor>("DX");
-    }
-
-    if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
-    if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
-
-    auto& place = ctx.template device_context<DeviceContext>();
-
-    Functor functor;
-    functor(place, Out, ddX, ddOut, dOut, dX);
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/activation_op.kps b/paddle/fluid/operators/activation_op.kps
index e33351520e6ddd1a1591ce0fe3193698276f2b89..7298a05827889c1bc63a6d5035a6fe2fe6aaf00a 100644
--- a/paddle/fluid/operators/activation_op.kps
+++ b/paddle/fluid/operators/activation_op.kps
@@ -126,59 +126,6 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
-template <typename T>
-struct CudaCELUFunctor : public BaseActivationFunctor<T> {
-  using CT = typename details::MPTypeTrait<T>::Type;
-  CT zero = static_cast<CT>(0.0f);
-  CT one = static_cast<CT>(1.0f);
-  float alpha;
-
-  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
-    return {{"alpha", &alpha}};
-  }
-
-  // celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
-  __device__ __forceinline__ T operator()(const T arg_x) const {
-    CT x = static_cast<CT>(arg_x);
-    CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
-    CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
-    return static_cast<T>(res);
-  }
-};
-
-template <typename T>
-struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
-  using MPType = typename details::MPTypeTrait<T>::Type;
-  MPType zero = static_cast<MPType>(0.0f);
-  MPType one = static_cast<MPType>(1.0f);
-  float alpha;
-
-  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
-    return {{"alpha", &alpha}};
-  }
-
-  // dx = dout, if alpha > 0 and x > 0
-  // dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
-  // dx = dout , if alpha < 0 and x > 0
-  // dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
-  __device__ __forceinline__ T operator()(const T arg_dout,
-                                          const T arg_x) const {
-    MPType dout = static_cast<MPType>(arg_dout);
-    MPType x = static_cast<MPType>(arg_x);
-    MPType a = static_cast<MPType>(alpha);
-    MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
-    MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
-    MPType temp_x_pos = static_cast<MPType>(x > zero);
-    MPType temp_x_neg = static_cast<MPType>(x <= zero);
-    return static_cast<T>(
-        dout *
-        (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) +
-         temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
-  }
-
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
-};
-
 template <typename DeviceContext, typename Functor>
 class ActivationCudaKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -357,79 +304,35 @@ namespace plat = paddle::platform;
                                   ops::ActivationGradCudaKernel>);
-/* ========================================================================== */
-
-/* ========================   celu register  ============================ */
-REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
-                                CudaCELUGradFunctor);
-
 REGISTER_OP_CUDA_KERNEL(
-    celu_grad_grad, ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
-                                              ops::CELUGradGradFunctor<float>>,
-    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
-                              ops::CELUGradGradFunctor<double>>,
-    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
-                              ops::CELUGradGradFunctor<plat::float16>>);
-/* ========================================================================== */
-
-/* ===========================   sqrt register  ============================= */
-
-REGISTER_OP_CUDA_KERNEL(
-    sqrt_grad_grad,
-    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
-                              ops::SqrtGradGradFunctor<float>>,
-    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
-                              ops::SqrtGradGradFunctor<double>>,
-    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
-                              ops::SqrtGradGradFunctor<plat::float16>>,
-    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
-                              ops::SqrtGradGradFunctor<plat::bfloat16>>);
-/* ========================================================================== */
-
-/* ===========================   rsqrt register  =============================
- */
-
+    relu6, ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                                     ops::CudaRelu6Functor<float>>,
+    ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                              ops::CudaRelu6Functor<double>>,
+    ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                              ops::CudaRelu6Functor<int>>,
+    ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                              ops::CudaRelu6Functor<int64_t>>,
+    ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                              ops::CudaRelu6Functor<plat::float16>>,
+    ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                              ops::CudaRelu6Functor<plat::bfloat16>>);
 REGISTER_OP_CUDA_KERNEL(
-    rsqrt_grad_grad,
-    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
-                               ops::RsqrtGradGradFunctor<float>>,
-    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
-                               ops::RsqrtGradGradFunctor<double>>,
-    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
-                               ops::RsqrtGradGradFunctor<plat::float16>>);
-/* ========================================================================== */
-
-/* ==========================   square register  ============================ */
-
-REGISTER_OP_CUDA_KERNEL(
-    square_grad_grad,
-    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
-                                ops::SquareGradGradFunctor<float>>,
-    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
-                                ops::SquareGradGradFunctor<double>>,
-    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
-                                ops::SquareGradGradFunctor<plat::float16>>,
-    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
-                                ops::SquareGradGradFunctor<plat::bfloat16>>,
-    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
-                                ops::SquareGradGradFunctor<int>>,
-    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
-                                ops::SquareGradGradFunctor<int64_t>>);
-/* ========================================================================== */
-
-/* ==========================   logit register  ============================ */
-namespace ops = paddle::operators;
-/* ========================================================================== */
-
-/* ==========================   exp register  ============================ */
-/* ========================================================================== */
-
-/* ==========================   expm1 register  ============================ */
-/* ========================================================================== */
+    relu6_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
+                                              ops::CudaRelu6GradFunctor<float>>,
+    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
+                                  ops::CudaRelu6GradFunctor<double>>,
+    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
+                                  ops::CudaRelu6GradFunctor<int>>,
+    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
+                                  ops::CudaRelu6GradFunctor<int64_t>>,
+    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
+                                  ops::CudaRelu6GradFunctor<plat::float16>>,
+    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
+                                  ops::CudaRelu6GradFunctor<plat::bfloat16>>);
 
 #define FOR_EACH_ACTIVATION_CUDA_OP(__macro)                                  \
   __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
-  __macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor);              \
   __macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);
 
 FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
@@ -452,13 +355,14 @@ REGISTER_OP_KERNEL(
     ops::ActivationGradCudaKernel>);
 
-REGISTER_OP_KERNEL(celu, KP, plat::XPUPlace,
-                   ops::ActivationCudaKernel<phi::XPUContext,
-                                             ops::CudaCELUFunctor<float>>);
+REGISTER_OP_KERNEL(
+    celu, KP, plat::XPUPlace,
+    ops::ActivationCudaKernel<phi::XPUContext,
+                              phi::funcs::CudaCELUFunctor<float>>);
 REGISTER_OP_KERNEL(
     celu_grad, KP, plat::XPUPlace,
     ops::ActivationGradCudaKernel<phi::XPUContext,
-                                  ops::CudaCELUGradFunctor<float>>);
+                                  phi::funcs::CudaCELUGradFunctor<float>>);
 
 REGISTER_OP_KERNEL(elu, KP, plat::XPUPlace,
                    ops::ActivationCudaKernel<phi::XPUContext,
                                              ops::CudaELUFunctor<float>>);
diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h
--- a/paddle/phi/kernels/activation_grad_kernel.h
+++ b/paddle/phi/kernels/activation_grad_kernel.h
@@ ... @@
+template <typename T, typename Context>
+void SqrtDoubleGradKernel(const Context& dev_ctx,
+                          const DenseTensor& out,
+                          const DenseTensor& dx,
+                          const DenseTensor& ddx,
+                          DenseTensor* dout,
+                          DenseTensor* ddout);
+
+template <typename T, typename Context>
+void RsqrtDoubleGradKernel(const Context& dev_ctx,
+                           const DenseTensor& out,
+                           const DenseTensor& dx,
+                           const DenseTensor& ddx,
+                           DenseTensor* dout,
+                           DenseTensor* ddout);
+
+template <typename T, typename Context>
+void CeluDoubleGradKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& dout,
+                          const DenseTensor& ddx,
+                          float alpha,
+                          DenseTensor* dx,
+                          DenseTensor* ddout);
+
+template <typename T, typename Context>
+void SquareDoubleGradKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& dout,
+                            const DenseTensor& ddx,
+                            DenseTensor* dx,
+                            DenseTensor* ddout);
+
 template <typename T, typename Context>
 void HardSwishGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
@@ -200,6 +233,7 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda);
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, threshold);
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta);
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps);
+DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, alpha);
 
 DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max);
diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h
index 8a40bacd395c447b3c681a60ac84ca4b436d5d2f..a14f732b6c8b6a4fd0699f028f9491499fc0ed26 100644
--- a/paddle/phi/kernels/activation_kernel.h
+++ b/paddle/phi/kernels/activation_kernel.h
@@ -78,6 +78,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Swish, beta)
+DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Celu, alpha)
 
 DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max)
 DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index d2b816de8fd2b0bec43f33aac96894b4816faedf..894f959c13193e1ce8d3e2f2b5d2735f3a1a26ee 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -167,6 +167,7 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, SwishGradFunctor, beta);
 DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
                                                MishGradFunctor,
                                                threshold);
+DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, CELUGradFunctor, alpha);
 
 DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
                                                BReluGradFunctor,
@@ -281,6 +282,10 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad,
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
                                           LeakyReluDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
+PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
+                                          SqrtDoubleGradKernel)
+PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
+                                          RsqrtDoubleGradKernel)
 
 PD_REGISTER_KERNEL(tanh_triple_grad,
                    CPU,
@@ -317,6 +322,15 @@ PD_REGISTER_KERNEL(square_grad,
                    double,
                    int,
                    int64_t) {}
+PD_REGISTER_KERNEL(square_double_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SquareDoubleGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {}
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
@@ -332,6 +346,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
+PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(celu_double_grad,
+                                          CeluDoubleGradKernel)
 
 PD_REGISTER_KERNEL(pow_grad,
                    CPU,
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index fe0643286cb027c8f8af947cfec30e74496e6c12..165627839a3083a4bd03a1e60c5a20250a56c74c 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -90,19 +90,19 @@ DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor)
 
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
-
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
                                      ThresholdedReluFunctor,
                                      threshold)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold)
-DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max)
-DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh, STanhFunctor, scale_a, scale_b)
-DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, SoftplusFunctor, beta, threshold)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishFunctor, beta)
+DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CELUFunctor, alpha)
+
+DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max)
+DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh, STanhFunctor, scale_a, scale_b)
+DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, SoftplusFunctor, beta, threshold)
 
 DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
                                      HardSigmoidFunctor,
                                      slope,
@@ -181,5 +181,6 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
+PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
 PD_REGISTER_KERNEL(
     pow, CPU, ALL_LAYOUT, phi::PowKernel, float, double, int, int64_t) {}
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 315d540541f925230b39171cdd65cbcdaf9b9d09..f80117ccec7998ce927533a0e59da43b9992cfcb 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -1832,6 +1832,196 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device>
+  void operator()(const Device& dev,
+                  const DenseTensor* Out,
+                  const DenseTensor* dX,
+                  const DenseTensor* ddX,
+                  DenseTensor* dOut,
+                  DenseTensor* ddOut) const {
+    auto* d = dev.eigen_device();
+    auto ddx = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
+    auto out = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
+    // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
+    // calculate dy first, so ddy can inplace ddx
+    if (dOut) {
+      auto dx = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
+      auto dout = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
+      dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
+    }
+    if (ddOut) {
+      auto ddout = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
+      ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
+    }
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
+template <typename T>
+struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device>
+  void operator()(const Device& dev,
+                  const DenseTensor* Out,
+                  const DenseTensor* dX,
+                  const DenseTensor* ddX,
+                  DenseTensor* dOut,
+                  DenseTensor* ddOut) const {
+    auto* d = dev.eigen_device();
+    auto ddx = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
+    auto out = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
+
+    // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
+    if (dOut) {
+      auto dx = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
+      auto dout = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
+      dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
+    }
+    if (ddOut) {
+      auto ddout = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
+      ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
+    }
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
+template <typename T>
+struct CELUFunctor : public BaseActivationFunctor<T> {
+  float alpha;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) =
+        (x < static_cast<T>(0))
+            .select(static_cast<T>(alpha) *
+                        ((x / static_cast<T>(alpha)).exp() - static_cast<T>(1)),
+                    x);
+  }
+};
+
+template <typename T>
+struct CELUGradFunctor : public BaseActivationFunctor<T> {
+  float alpha;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    auto temp_a_pos = static_cast<T>(alpha > 0);
+    auto temp_a_neg = static_cast<T>(alpha <= 0);
+    auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
+    auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
+
+    // dx = dout, if alpha > 0 and x > 0
+    // dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
+    // dx = dout , if alpha < 0 and x > 0
+    // dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
+    dx.device(d) =
+        dout * temp_a_pos * temp_x_pos +
+        dout * (x / static_cast<T>(alpha)).exp() * temp_a_pos * temp_x_neg +
+        dout * temp_a_neg * temp_x_pos +
+        dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg;
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
+template <typename T>
+struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
+  float alpha;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+  template <typename Device>
+  void operator()(const Device& dev,
+                  const DenseTensor* X,
+                  const DenseTensor* dOut,
+                  const DenseTensor* ddX,
+                  DenseTensor* dX,
+                  DenseTensor* ddOut) const {
+    auto* d = dev.eigen_device();
+    auto ddx = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad"));
+    auto x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad"));
+
+    if (dX) {
+      auto dx = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad"));
+      auto dout = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad"));
+      dx.device(*d) = ddx * dout / static_cast<T>(alpha) *
+                      (x / static_cast<T>(alpha)).exp() *
+                      (x <= static_cast<T>(0)).template cast<T>();
+    }
+
+    if (ddOut) {
+      auto ddout = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad"));
+      ddout.device(*d) = ddx *
+                         ((x > static_cast<T>(0)).template cast<T>() +
+                          (x / static_cast<T>(alpha)).exp() *
+                              (x <= static_cast<T>(0)).template cast<T>())
+                             .template cast<T>();
+    }
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
+template <typename T>
+struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device>
+  void operator()(const Device& dev,
+                  const DenseTensor* X,
+                  const DenseTensor* dOut,
+                  const DenseTensor* ddX,
+                  DenseTensor* dX,
+                  DenseTensor* ddOut) const {
+    auto* d = dev.eigen_device();
+    auto ddx = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
+    auto x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
+    // square GradGrad: ddy=2x*ddx, dx=2dy*ddx
+    // calculate dx first, so ddy can inplace ddx
+    if (dX) {
+      auto dx = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
+      auto dout = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
+      dx.device(*d) = ddx * static_cast<T>(2) * dout;
+    }
+    if (ddOut) {
+      auto ddout = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
+      ddout.device(*d) = ddx * static_cast<T>(2) * x;
+    }
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
 template <typename T>
 struct CudaReluFunctor : public BaseActivationFunctor<T> {
@@ -3091,6 +3281,59 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaCELUFunctor : public BaseActivationFunctor<T> {
+  using CT = typename phi::dtype::MPTypeTrait<T>::Type;
+  CT zero = static_cast<CT>(0.0f);
+  CT one = static_cast<CT>(1.0f);
+  float alpha;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+
+  // celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
+  __device__ __forceinline__ T operator()(const T arg_x) const {
+    CT x = static_cast<CT>(arg_x);
+    CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
+    CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
+    return static_cast<T>(res);
+  }
+};
+
+template <typename T>
+struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType zero = static_cast<MPType>(0.0f);
+  MPType one = static_cast<MPType>(1.0f);
+  float alpha;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+
+  // dx = dout, if alpha > 0 and x > 0
+  // dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
+  // dx = dout , if alpha < 0 and x > 0
+  // dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
+  __device__ __forceinline__ T operator()(const T arg_dout,
+                                          const T arg_x) const {
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType x = static_cast<MPType>(arg_x);
+    MPType a = static_cast<MPType>(alpha);
+    MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
+    MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
+    MPType temp_x_pos = static_cast<MPType>(x > zero);
+    MPType temp_x_neg = static_cast<MPType>(x <= zero);
+    return static_cast<T>(
+        dout *
+        (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) +
+         temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 #endif
 
 }  // namespace funcs
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 944cd2ea1024d391b1da286e7490de657cb61bd2..1479fd494435da0a6ee1e24b577ab724a9c302c3 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -221,6 +221,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
 DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
                                                CudaMishGradFunctor,
                                                threshold);
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu,
+                                               CudaCELUGradFunctor,
+                                               alpha);
 
 DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
                                                CudaBReluGradFunctor,
@@ -351,7 +354,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_double_grad, RsqrtDoubleGradKernel)
 
 PD_REGISTER_KERNEL(exp_grad,
                    GPU,
@@ -396,6 +401,16 @@ PD_REGISTER_KERNEL(square_grad,
                    int64_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(square_double_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SquareDoubleGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad,
                                    SigmoidDoubleGradKernel)
@@ -418,6 +433,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_double_grad, CeluDoubleGradKernel)
 
 PD_REGISTER_KERNEL(pow_grad,
                    GPU,
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 8cc546ba73a067a42f3f658fe849c98d3e62cc8d..8db31c5ed5b799089a7312f541314fbae57ed073 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -118,8 +118,8 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda)
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha)
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, CudaSwishFunctor, beta)
-
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, CudaMishFunctor, threshold)
+DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CudaCELUFunctor, alpha)
 
 DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max)
 DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Stanh, CudaSTanhFunctor, scale_a, scale_b)
@@ -234,6 +234,7 @@ PD_REGISTER_KERNEL(square,
                    int64_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
+
 PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
 PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
 PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
@@ -251,6 +252,7 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
+PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
 PD_REGISTER_KERNEL(pow,
                    GPU,
                    ALL_LAYOUT,
diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h
index 12fcac7d62d31afa593ecf4b1b44130fd1717623..04391d2538c89e882d52c029c80b946083c4a2ce 100644
--- a/paddle/phi/kernels/impl/activation_grad_impl.h
+++ b/paddle/phi/kernels/impl/activation_grad_impl.h
@@ -335,4 +335,87 @@ void PowGradKernel(const Context& dev_ctx,
   functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
 }
 
+template <typename T, typename Context>
+void SqrtDoubleGradKernel(const Context& dev_ctx,
+                          const DenseTensor& out,
+                          const DenseTensor& dx,
+                          const DenseTensor& ddx,
+                          DenseTensor* dout,
+                          DenseTensor* ddout) {
+  if (dout) {
+    dout->Resize(out.dims());
+    dev_ctx.template Alloc<T>(dout);
+  }
+  if (ddout) {
+    ddout->Resize(out.dims());
+    dev_ctx.template Alloc<T>(ddout);
+  }
+
+  phi::funcs::SqrtGradGradFunctor<T> functor;
+  functor(dev_ctx, &out, &dx, &ddx, dout, ddout);
+}
+
+// rsqrt Grad: dx = -0.5 * dy * y * y * y
+// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx
+template <typename T, typename Context>
+void RsqrtDoubleGradKernel(const Context& dev_ctx,
+                           const DenseTensor& out,
+                           const DenseTensor& dx,
+                           const DenseTensor& ddx,
+                           DenseTensor* dout,
+                           DenseTensor* ddout) {
+  if (dout) {
+    dout->Resize(out.dims());
+    dev_ctx.template Alloc<T>(dout);
+  }
+  if (ddout) {
+    ddout->Resize(out.dims());
+    dev_ctx.template Alloc<T>(ddout);
+  }
+
+  phi::funcs::RsqrtGradGradFunctor<T> functor;
+  functor(dev_ctx, &out, &dx, &ddx, dout, ddout);
+}
+
+template <typename T, typename Context>
+void CeluDoubleGradKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& dout,
+                          const DenseTensor& ddx,
+                          float alpha,
+                          DenseTensor* dx,
+                          DenseTensor* ddout) {
+  if (dx) {
+    dx->Resize(x.dims());
+    dev_ctx.template Alloc<T>(dx);
+  }
+  if (ddout) {
+    dev_ctx.template Alloc<T>(ddout);
+  }
+
+  phi::funcs::CELUGradGradFunctor<T> functor;
+  auto attrs = functor.GetAttrs();
+  *(attrs[0].second) = alpha;
+  functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
+}
+
+template <typename T, typename Context>
+void SquareDoubleGradKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& dout,
+                            const DenseTensor& ddx,
+                            DenseTensor* dx,
+                            DenseTensor* ddout) {
+  if (dx) {
+    dx->Resize(x.dims());
+    dev_ctx.template Alloc<T>(dx);
+  }
+  if (ddout) {
+    dev_ctx.template Alloc<T>(ddout);
+  }
+
+  phi::funcs::SquareGradGradFunctor<T> functor;
+  functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
+}
+
 }  // namespace phi
diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc
index 157eaa279debb05d37bc58ca7c90bed8c7621640..ac0c3021a3273fb82264f6bfa9241673e5d12a12 100644
--- a/paddle/phi/ops/compat/activation_sig.cc
+++ b/paddle/phi/ops/compat/activation_sig.cc
@@ -67,6 +67,7 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", );      // NOLINT
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log2, "log2", );    // NOLINT
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log10, "log10", );  // NOLINT
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log1p, "log1p", );  // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Celu, "celu", "alpha");  // NOLINT
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardSwish,
                                "hard_swish",
                                "threshold" comma "scale" comma
@@ -181,6 +182,30 @@ KernelSignature LogDoubleGradOpArgumentMapping(
       "log_double_grad", {"X", "DOut", "DDX"}, {}, {"DX", "DDOut"});
 }
 
+KernelSignature SqrtDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "sqrt_double_grad", {"Out", "DX", "DDX"}, {}, {"DOut", "DDOut"});
+}
+
+KernelSignature RsqrtDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "rsqrt_double_grad", {"Out", "DX", "DDX"}, {}, {"DOut", "DDOut"});
+}
+
+KernelSignature CeluDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "celu_double_grad", {"X", "DOut", "DDX"}, {"alpha"}, {"DX", "DDOut"});
+}
+
+KernelSignature SquareDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "square_double_grad", {"X", "DOut", "DDX"}, {}, {"DX", "DDOut"});
+}
+
 KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.HasInput("FactorTensor")) {
     return KernelSignature("pow", {"X"}, {"FactorTensor"}, {"Out"});
@@ -209,6 +234,10 @@ PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad);
 PD_REGISTER_BASE_KERNEL_NAME(elu_grad_grad, elu_double_grad);
 PD_REGISTER_BASE_KERNEL_NAME(sigmoid_grad_grad, sigmoid_double_grad);
 PD_REGISTER_BASE_KERNEL_NAME(log_grad_grad, log_double_grad);
+PD_REGISTER_BASE_KERNEL_NAME(sqrt_grad_grad, sqrt_double_grad);
+PD_REGISTER_BASE_KERNEL_NAME(rsqrt_grad_grad, rsqrt_double_grad);
+PD_REGISTER_BASE_KERNEL_NAME(celu_grad_grad, celu_double_grad);
+PD_REGISTER_BASE_KERNEL_NAME(square_grad_grad, square_double_grad);
 
 PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping);
@@ -229,7 +258,11 @@ PD_REGISTER_ARG_MAPPING_FN(square_grad, phi::SquareGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(reciprocal_grad,
                            phi::ReciprocalGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(sqrt_grad, phi::SqrtGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(sqrt_grad_grad,
+                           phi::SqrtDoubleGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(rsqrt_grad, phi::RsqrtGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(rsqrt_grad_grad,
+                           phi::RsqrtDoubleGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(softplus_grad, phi::SoftplusGradOpArgumentMapping);
@@ -286,3 +319,8 @@ PD_REGISTER_ARG_MAPPING_FN(floor_grad, phi::FloorGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(ceil_grad, phi::CeilGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(pow, phi::PowOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(celu_grad, phi::CeluGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(celu_grad_grad,
+                           phi::CeluDoubleGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(square_grad_grad,
+                           phi::SquareDoubleGradOpArgumentMapping);
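
For review convenience (this note is not part of the patch): the second-order formulas implemented by the moved functors follow directly from the first-order gradients. Writing y = f(x), dout for the incoming output gradient, dx = dout * f'(x), and ddx for the perturbation of dx fed into the double-grad op, a short derivation in LaTeX:

\[
\begin{aligned}
y = \sqrt{x}:\quad & dx = \tfrac{dout}{2y}, & DDOut &= \tfrac{ddx}{2y}, & DOut &= -\tfrac{dx \cdot ddx}{y},\\
y = x^{-1/2}:\quad & dx = -\tfrac{1}{2}\,dout\,y^{3}, & DDOut &= -\tfrac{1}{2}\,ddx\,y^{3}, & DOut &= \tfrac{3}{y}\,dx \cdot ddx,\\
y = x^{2}:\quad & dx = 2x\,dout, & DDOut &= 2x\,ddx, & DX &= 2\,dout \cdot ddx,\\
y = \mathrm{celu}_{\alpha}(x):\quad & dx = dout \cdot c'(x), & DDOut &= ddx \cdot c'(x), & DX &= ddx \cdot dout \cdot c''(x),
\end{aligned}
\]

with \(c'(x) = 1\) for \(x > 0\), \(c'(x) = e^{x/\alpha}\) for \(x \le 0\), and \(c''(x) = \tfrac{1}{\alpha} e^{x/\alpha}\,[x \le 0]\). These match the bodies of SqrtGradGradFunctor, RsqrtGradGradFunctor, SquareGradGradFunctor, and CELUGradGradFunctor above (note the sqrt comment "dy = -1 * dx * ddx" omits the division by y that the code performs).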
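
Since CELU is the only moved double grad with an attribute, here is a minimal standalone sketch (hypothetical file celu_grad_check.cc, plain C++ with no Paddle dependency, not part of the patch) that cross-checks the closed forms used by CELUGradFunctor and CELUGradGradFunctor against central finite differences:

// celu_grad_check.cc -- hypothetical standalone check, no Paddle dependency.
#include <cmath>
#include <cstdio>
#include <initializer_list>

// celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1)), for alpha > 0.
double celu(double x, double alpha) {
  return std::fmax(0.0, x) +
         std::fmin(0.0, alpha * (std::exp(x / alpha) - 1.0));
}

// First derivative: 1 for x > 0, exp(x/alpha) for x <= 0 (CELUGradFunctor).
double celu_dx(double x, double alpha) {
  return x > 0.0 ? 1.0 : std::exp(x / alpha);
}

// Second derivative: 0 for x > 0, exp(x/alpha) / alpha for x <= 0
// (the factor multiplied into ddx * dout by CELUGradGradFunctor).
double celu_ddx(double x, double alpha) {
  return x > 0.0 ? 0.0 : std::exp(x / alpha) / alpha;
}

int main() {
  const double alpha = 1.5;
  const double h = 1e-6;
  for (double x : {-2.0, -0.5, 0.7, 3.0}) {  // avoid the kink at x == 0
    double fd1 = (celu(x + h, alpha) - celu(x - h, alpha)) / (2.0 * h);
    double fd2 = (celu_dx(x + h, alpha) - celu_dx(x - h, alpha)) / (2.0 * h);
    std::printf("x=%5.2f  |fd1-d1|=%.2e  |fd2-d2|=%.2e\n", x,
                std::fabs(fd1 - celu_dx(x, alpha)),
                std::fabs(fd2 - celu_ddx(x, alpha)));
  }
  return 0;
}

Both differences should come out on the order of h^2 (about 1e-12 here); the same spot-check applies to the sqrt/rsqrt/square formulas summarized above.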