From 13c99434cdfeff416c4e0027aebb4f5600aec56f Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 23 Mar 2022 10:29:12 +0800 Subject: [PATCH] [Phi]Move log/log2/log10/log1p Kernels to Phi (#40785) * move activation * fix bugs when run ce --- paddle/fluid/framework/operator.cc | 10 +- paddle/fluid/operators/activation_op.cc | 12 +- paddle/fluid/operators/activation_op.h | 151 +----------- paddle/fluid/operators/activation_op.kps | 112 +-------- paddle/phi/kernels/activation_grad_kernel.h | 12 + paddle/phi/kernels/activation_kernel.h | 4 + .../phi/kernels/cpu/activation_grad_kernel.cc | 9 + paddle/phi/kernels/cpu/activation_kernel.cc | 8 + paddle/phi/kernels/funcs/activation_functor.h | 220 ++++++++++++++++++ .../phi/kernels/gpu/activation_grad_kernel.cu | 15 ++ paddle/phi/kernels/gpu/activation_kernel.cu | 10 +- .../phi/kernels/impl/activation_grad_impl.h | 18 ++ paddle/phi/ops/compat/activation_sig.cc | 16 ++ 13 files changed, 332 insertions(+), 265 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 42fbeb5d29c..15777c287b4 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1122,7 +1122,15 @@ static void CheckTensorNANOrInf(const std::string& op_type, bool OperatorWithKernel::SupportsMKLDNN( const proto::VarType::Type data_type) const { - auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); + auto op_kernel_iter = OperatorWithKernel::AllOpKernels().find(type_); + if (op_kernel_iter == OperatorWithKernel::AllOpKernels().end()) { + VLOG(6) << "Warning: " << type_ << " don't find its MKLDNN Kernel in Fluid " + "Registered Kernels. And We don't " + "search its kernels in phi lib, " + "SupportsMKLDNN() return false."; + return false; + } + auto& op_kernels = op_kernel_iter->second; return std::any_of(op_kernels.begin(), op_kernels.end(), [data_type](OpKernelMap::const_reference kern_pair) { return platform::is_cpu_place(kern_pair.first.place_) && diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 845d0ed073b..8f7b62a2c9d 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -1496,6 +1496,9 @@ REGISTER_ACTIVATION_OP(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); REGISTER_ACTIVATION_OP(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); +REGISTER_ACTIVATION_OP(log2, Log2, Log2Functor, Log2GradFunctor); +REGISTER_ACTIVATION_OP(log10, Log10, Log10Functor, Log10GradFunctor); +REGISTER_ACTIVATION_OP(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); /* ========================== sigmoid register ============================= */ @@ -1867,15 +1870,6 @@ REGISTER_OPERATOR( ops::ActivationOpDoubleGrad::FwdDeps()>, ops::ActivationDoubleGradOpInplaceInferer); -REGISTER_ACTIVATION_CPU_KERNEL(log, Log, LogFunctor, LogGradFunctor); - -REGISTER_OP_CPU_KERNEL( - log_grad_grad, ops::LogDoubleGradKernel>, - ops::LogDoubleGradKernel>, - ops::LogDoubleGradKernel>); /* ========================================================================== */ /* ========================== register checkpoint ===========================*/ diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index f1984af6e15..7db5675c16b 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -281,6 +281,11 @@ USE_PHI_DOUBLE_GRAD_FUNCTOR(Sigmoid) USE_PHI_TRIPLE_GRAD_FUNCTOR(Sigmoid) USE_PHI_FUNCTOR(LogSigmoid) 
USE_PHI_FUNCTOR(HardSigmoid) +USE_PHI_FUNCTOR(Log) +USE_PHI_DOUBLE_GRAD_FUNCTOR(Log) +USE_PHI_FUNCTOR(Log2) +USE_PHI_FUNCTOR(Log10) +USE_PHI_FUNCTOR(Log1p) template using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor; @@ -448,88 +453,6 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor { } }; -// log(x) = natural logarithm of x -template -struct LogFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.log(); - } -}; - -template -struct LogGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * (static_cast(1) / x); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -// log2(x) = logarithm to the base 2 of the elements of x -template -struct Log2Functor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.log() / static_cast(log(2)); - } -}; - -// the gradient of log2(x) is 1/(x*ln(2)) -template -struct Log2GradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * static_cast(1) / (x * static_cast(log(2))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -// log10(x) = logarithm to the base 10 of the elements of x -template -struct Log10Functor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.log() / static_cast(log(10)); - } -}; - -// the gradient of log10(x) is 1/(x*ln(10)) -template -struct Log10GradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * static_cast(1) / (x * static_cast(log(10))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -// log1p(x) = natural logarithm of x+1 -template -struct Log1pFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = (static_cast(1) + x).log(); - } -}; - -template -struct Log1pGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * (static_cast(1) / (x + static_cast(1))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - // square(x) = x^2 template struct SquareFunctor : public BaseActivationFunctor { @@ -1197,37 +1120,6 @@ class SquareDoubleGradKernel } }; -template -class LogDoubleGradKernel - : public SquareDoubleGradKernel {}; - -template -class ELUDoubleGradKernel - : public framework::OpKernel { - public: - using T = typename Functor::ELEMENT_TYPE; - void Compute(const framework::ExecutionContext& ctx) const override { - const framework::Tensor *X, *ddX, *dOut; - X = ddX = dOut = nullptr; - framework::Tensor *dX, *ddOut; - dX = ddOut = nullptr; - - ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut); - - if (dX) dX->mutable_data(X->dims(), ctx.GetPlace()); - if (ddOut) ddOut->mutable_data(ctx.GetPlace()); - - auto& place = ctx.template device_context(); - - Functor functor; - auto attrs = functor.GetAttrs(); - for (auto& attr : attrs) { - *attr.second = ctx.Attr(attr.first); - } - functor(place, X, ddX, ddOut, dOut, dX); - } -}; - template class CELUDoubleGradKernel : public framework::OpKernel { @@ 
-1522,36 +1414,6 @@ class LogitGradKernel : public framework::OpKernel { } }; -template -struct LogGradGradFunctor : public BaseActivationFunctor { - template - void operator()(const Device& dev, const framework::Tensor* X, - const framework::Tensor* ddX, framework::Tensor* ddOut, - const framework::Tensor* dOut, framework::Tensor* dX) const { - auto* d = dev.eigen_device(); - auto ddx = framework::EigenVector::Flatten( - GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad")); - auto x = framework::EigenVector::Flatten( - GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad")); - // ddout = ddx / x; dx = -(dout / x) * (ddx / x) - // calculate dx first, so ddout can inplace ddx - if (dX) { - auto dout = framework::EigenVector::Flatten( - GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad")); - auto dx = framework::EigenVector::Flatten( - GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad")); - dx.device(*d) = dout * static_cast(-1) * ddx / (x * x); - } - if (ddOut) { - auto ddout = framework::EigenVector::Flatten( - GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad")); - ddout.device(*d) = ddx * static_cast(1) / x; - } - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - } // namespace operators } // namespace paddle @@ -1560,9 +1422,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor { __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ __macro(round, Round, RoundFunctor, ZeroGradFunctor); \ __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ - __macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \ - __macro(log2, Log2, Log2Functor, Log2GradFunctor); \ - __macro(log10, Log10, Log10Functor, Log10GradFunctor); \ __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \ __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ diff --git a/paddle/fluid/operators/activation_op.kps b/paddle/fluid/operators/activation_op.kps index 7c1b2880801..bb08cee5bcd 100644 --- a/paddle/fluid/operators/activation_op.kps +++ b/paddle/fluid/operators/activation_op.kps @@ -131,27 +131,6 @@ struct CudaExpm1GradFunctor : public BaseActivationFunctor { } }; -template -struct CudaLogFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // log(x) = log(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(log(x)); - } -}; - -template -struct CudaLogGradFunctor : public BaseActivationFunctor { - // dx = dout / x - __device__ __forceinline__ T operator()(const T dout, const T x) const { - return dout / x; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct CudaSquareFunctor : public BaseActivationFunctor { // square(x) = x * x @@ -220,78 +199,6 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor { } }; -template -struct CudaLog1pFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - - // log1p(x) = log(1 + x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(log(one + x)); - } -}; - -template -struct CudaLog1pGradFunctor : public BaseActivationFunctor { - T one = static_cast(1.0f); - - // dx = dout / (1 + x) - __device__ __forceinline__ T operator()(const T dout, const T x) const { - return dout / (one + x); - } - - static constexpr 
ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -template -struct CudaLog2Functor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // log2(x) = log2(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(log2(x)); - } -}; - -template -struct CudaLog2GradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - T log_two = static_cast(log(static_cast(2.0f))); - - // dx = dout / (x * log(2)) - __device__ __forceinline__ T operator()(const T dout, const T x) const { - return dout / (x * log_two); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -template -struct CudaLog10Functor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // log10(x) = log10(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(log10(x)); - } -}; - -template -struct CudaLog10GradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - T log_ten = static_cast(log(static_cast(10.0f))); - - // dx = dout / (x * log(10)) - __device__ __forceinline__ T operator()(const T dout, const T x) const { - return dout / (x * log_ten); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct CudaSoftReluFunctor : public BaseActivationFunctor { using MPType = typename details::MPTypeTrait::Type; @@ -773,6 +680,10 @@ USE_PHI_FUNCTOR(CudaELU) USE_PHI_FUNCTOR(CudaSigmoid) USE_PHI_FUNCTOR(CudaLogSigmoid) USE_PHI_FUNCTOR(CudaHardSigmoid) +USE_PHI_FUNCTOR(CudaLog) +USE_PHI_FUNCTOR(CudaLog2) +USE_PHI_FUNCTOR(CudaLog10) +USE_PHI_FUNCTOR(CudaLog1p) template using CudaELUGradNegativeAlphaFunctor = @@ -975,18 +886,6 @@ REGISTER_OP_CUDA_KERNEL( ops::CudaExpm1GradFunctor>); /* ========================================================================== */ -/* ========================== Log register ==================================*/ -REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor); - -REGISTER_OP_CUDA_KERNEL( - log_grad_grad, ops::LogDoubleGradKernel>, - ops::LogDoubleGradKernel>, - ops::LogDoubleGradKernel>); -/* ========================================================================== */ - #define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \ __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \ CudaSoftShrinkGradFunctor); \ @@ -995,9 +894,6 @@ REGISTER_OP_CUDA_KERNEL( __macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \ __macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \ CudaReciprocalGradFunctor); \ - __macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor); \ - __macro(log2, Log2, CudaLog2Functor, CudaLog2GradFunctor); \ - __macro(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor); \ __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \ __macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor); \ __macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor); \ diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index 241a80d85ea..6ad28f348f2 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -135,6 +135,14 @@ void SigmoidTripleGradKernel(const Context& dev_ctx, DenseTensor* d_dout, DenseTensor* d_ddx); +template +void LogDoubleGradKernel(const 
Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + const DenseTensor& ddx, + DenseTensor* dx, + DenseTensor* ddout); + DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos); DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan); DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos); @@ -149,6 +157,10 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Atanh); DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink); DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Silu); DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid); +DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log); +DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log2); +DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log10); +DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log1p); DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu); DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh); diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h index dbc63a636ed..785d1089f06 100644 --- a/paddle/phi/kernels/activation_kernel.h +++ b/paddle/phi/kernels/activation_kernel.h @@ -56,6 +56,10 @@ DECLARE_ACTIVATION_KERNEL(TanhShrink) DECLARE_ACTIVATION_KERNEL(Silu) DECLARE_ACTIVATION_KERNEL(Sigmoid) DECLARE_ACTIVATION_KERNEL(LogSigmoid) +DECLARE_ACTIVATION_KERNEL(Log) +DECLARE_ACTIVATION_KERNEL(Log2) +DECLARE_ACTIVATION_KERNEL(Log10) +DECLARE_ACTIVATION_KERNEL(Log1p) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold) diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index c5822615962..0776e570e9c 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -121,6 +121,10 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, AtanhGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, TanhShrinkGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, SiluGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid, LogSigmoidGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, LogGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log2, Log2GradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log10, Log10GradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log1p, Log1pGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor); @@ -233,3 +237,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_sigmoid_grad, HardSigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(logsigmoid_grad, LogSigmoidGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel) +PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel) diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index 1d7b77ea444..c8709261d2c 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -74,6 +74,10 @@ DEFINE_CPU_ACTIVATION_KERNEL(TanhShrink, TanhShrinkFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Silu, SiluFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Sigmoid, SigmoidFunctor) DEFINE_CPU_ACTIVATION_KERNEL(LogSigmoid, LogSigmoidFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor) 
+DEFINE_CPU_ACTIVATION_KERNEL(Log2, Log2Functor) +DEFINE_CPU_ACTIVATION_KERNEL(Log10, Log10Functor) +DEFINE_CPU_ACTIVATION_KERNEL(Log1p, Log1pFunctor) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, @@ -118,3 +122,7 @@ PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel) PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) +PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) +PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel) +PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel) +PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel) diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 6c5ffbd06e3..6e536bd00a4 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -1223,6 +1223,133 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor { } }; +// log(x) = natural logarithm of x +template +struct LogFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.log(); + } +}; + +template +struct LogGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast(1) / x); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +// log2(x) = logarithm to the base 2 of the elements of x +template +struct Log2Functor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.log() / static_cast(log(2)); + } +}; + +// the gradient of log2(x) is 1/(x*ln(2)) +template +struct Log2GradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * static_cast(1) / (x * static_cast(log(2))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +// log10(x) = logarithm to the base 10 of the elements of x +template +struct Log10Functor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.log() / static_cast(log(10)); + } +}; + +// the gradient of log10(x) is 1/(x*ln(10)) +template +struct Log10GradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * static_cast(1) / (x * static_cast(log(10))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +// log1p(x) = natural logarithm of x+1 +template +struct Log1pFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = (static_cast(1) + x).log(); + } +}; + +template +struct Log1pGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast(1) / (x + static_cast(1))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct LogGradGradFunctor : public BaseActivationFunctor { + template + void operator()(const Device& dev, + const DenseTensor* X, + const DenseTensor* ddX, + DenseTensor* ddOut, + const DenseTensor* dOut, + DenseTensor* dX) const { + auto* d = dev.eigen_device(); + auto ddx = 
EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad")); + auto x = EigenVector::Flatten( + GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad")); + // ddout = ddx / x; dx = -(dout / x) * (ddx / x) + // calculate dx first, so ddout can inplace ddx + if (dX) { + auto dout = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad")); + auto dx = EigenVector::Flatten( + GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad")); + dx.device(*d) = dout * static_cast(-1) * ddx / (x * x); + } + if (ddOut) { + auto ddout = EigenVector::Flatten( + GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad")); + ddout.device(*d) = ddx * static_cast(1) / x; + } + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) template struct CudaReluFunctor : public BaseActivationFunctor { @@ -1970,6 +2097,99 @@ struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor { } }; +template +struct CudaLogFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // log(x) = log(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(log(x)); + } +}; + +template +struct CudaLogGradFunctor : public BaseActivationFunctor { + // dx = dout / x + __device__ __forceinline__ T operator()(const T dout, const T x) const { + return dout / x; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaLog1pFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // log1p(x) = log(1 + x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(log(one + x)); + } +}; + +template +struct CudaLog1pGradFunctor : public BaseActivationFunctor { + T one = static_cast(1.0f); + + // dx = dout / (1 + x) + __device__ __forceinline__ T operator()(const T dout, const T x) const { + return dout / (one + x); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaLog2Functor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // log2(x) = log2(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(log2(x)); + } +}; + +template +struct CudaLog2GradFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + T log_two = static_cast(log(static_cast(2.0f))); + + // dx = dout / (x * log(2)) + __device__ __forceinline__ T operator()(const T dout, const T x) const { + return dout / (x * log_two); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaLog10Functor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // log10(x) = log10(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(log10(x)); + } +}; + +template +struct CudaLog10GradFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + T log_ten = static_cast(log(static_cast(10.0f))); + + // dx = dout / (x * log(10)) + __device__ __forceinline__ T operator()(const T dout, const T x) const { + return dout / (x * log_ten); + } + + static constexpr 
ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + #endif } // namespace funcs diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index c912d0c4686..3cc41555a89 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -177,6 +177,10 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, CudaAtanhGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, CudaTanhShrinkGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, CudaSiluGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid, CudaLogSigmoidGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, CudaLogGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log2, CudaLog2GradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log10, CudaLog10GradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log1p, CudaLog1pGradFunctor); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, CudaLeakyReluGradFunctor, @@ -300,3 +304,14 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_sigmoid_grad, HardSigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(logsigmoid_grad, LogSigmoidGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel) +PD_REGISTER_KERNEL(log_double_grad, + GPU, + ALL_LAYOUT, + phi::LogDoubleGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index 6b598c764de..fb4e2e07b21 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -19,7 +19,7 @@ limitations under the License. 
*/ #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" -#include "paddle/phi/kernels/impl/activation_grad_impl.h" +#include "paddle/phi/kernels/impl/activation_impl.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" @@ -93,6 +93,10 @@ DEFINE_GPU_ACTIVATION_KERNEL(TanhShrink, CudaTanhShrinkFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Silu, CudaSiluFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Sigmoid, CudaSigmoidFunctor) DEFINE_GPU_ACTIVATION_KERNEL(LogSigmoid, CudaLogSigmoidFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Log2, CudaLog2Functor) +DEFINE_GPU_ACTIVATION_KERNEL(Log10, CudaLog10Functor) +DEFINE_GPU_ACTIVATION_KERNEL(Log1p, CudaLog1pFunctor) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, @@ -164,3 +168,7 @@ PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel) PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) +PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) +PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel) +PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel) +PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel) diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h index 7d6b6dc72ea..7ef8a0887c7 100644 --- a/paddle/phi/kernels/impl/activation_grad_impl.h +++ b/paddle/phi/kernels/impl/activation_grad_impl.h @@ -275,4 +275,22 @@ void SigmoidTripleGradKernel(const Context& dev_ctx, d_ddx); } +template +void LogDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + const DenseTensor& ddx, + DenseTensor* dx, + DenseTensor* ddout) { + if (dx) { + dx->Resize(x.dims()); + dev_ctx.template Alloc(dx); + } + if (ddout) { + dev_ctx.template Alloc(ddout); + } + funcs::LogGradGradFunctor functor; + functor(dev_ctx, &x, &ddx, ddout, &dout, dx); +} + } // namespace phi diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc index 7ae0dc45c5e..8b4884e35b6 100644 --- a/paddle/phi/ops/compat/activation_sig.cc +++ b/paddle/phi/ops/compat/activation_sig.cc @@ -57,6 +57,10 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardShrink, "hard_shrink", "threshold"); DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(TanhShrink, "tanh_shrink", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Silu, "silu", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(LogSigmoid, "logsigmoid", ); // NOLINT +DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", ); // NOLINT +DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log2, "log2", ); // NOLINT +DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log10, "log10", ); // NOLINT +DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log1p, "log1p", ); // NOLINT DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", ); // NOLINT DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Tanh, "tanh", ); // NOLINT @@ -125,6 +129,12 @@ KernelSignature EluDoubleGradOpArgumentMapping( "elu_double_grad", {"X", "DOut", "DDX"}, {"alpha"}, {"DX", "DDOut"}); } +KernelSignature LogDoubleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "log_double_grad", {"X", "DOut", "DDX"}, {}, {"DX", "DDOut"}); +} + } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad); @@ -134,6 +144,7 @@ PD_REGISTER_BASE_KERNEL_NAME(softshrink, soft_shrink); PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad); 
 PD_REGISTER_BASE_KERNEL_NAME(elu_grad_grad, elu_double_grad);
 PD_REGISTER_BASE_KERNEL_NAME(sigmoid_grad_grad, sigmoid_double_grad);
+PD_REGISTER_BASE_KERNEL_NAME(log_grad_grad, log_double_grad);
 
 PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping);
@@ -181,3 +192,8 @@ PD_REGISTER_ARG_MAPPING_FN(logsigmoid_grad,
                            phi::LogSigmoidGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(hard_sigmoid_grad,
                            phi::HardSigmoidGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(log_grad, phi::LogGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(log_grad_grad, phi::LogDoubleGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(log2_grad, phi::Log2GradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(log10_grad, phi::Log10GradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(log1p_grad, phi::Log1pGradOpArgumentMapping);
-- 
GitLab
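
Note on the backward formulas implemented by the moved functors. The identities below are restated from the in-code comments in the patch (the symbols x, out, dout, dx, ddx, ddout are the ones used there); this is a reading aid for reviewing the kernels, not additional behavior:

\[
\frac{d}{dx}\log x = \frac{1}{x}, \qquad
\frac{d}{dx}\log_2 x = \frac{1}{x\ln 2}, \qquad
\frac{d}{dx}\log_{10} x = \frac{1}{x\ln 10}, \qquad
\frac{d}{dx}\log(1+x) = \frac{1}{1+x}
\]

Each first-order grad functor therefore computes dx = dout * f'(x) and depends only on the forward input X (ActBwdOpFwdDeps::kDepX). For the second-order kernel (LogDoubleGradKernel / LogGradGradFunctor), with out = log(x), the functor computes

\[
ddout = \frac{ddx}{x}, \qquad dx = -\,dout \cdot \frac{ddx}{x^{2}},
\]

evaluating dx before ddout so that ddout may reuse the ddx buffer in place, as the comment "calculate dx first, so ddout can inplace ddx" notes.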