From c552d1acc4a1bb289555cd70d925ae20d10151a2 Mon Sep 17 00:00:00 2001 From: phlrain Date: Wed, 16 Mar 2022 00:39:59 +0000 Subject: [PATCH] add forward case --- paddle/fluid/operators/activation_op.cc | 28 -- paddle/fluid/operators/activation_op.h | 169 +--------- paddle/fluid/operators/activation_op.kps | 146 --------- .../operators/math/selected_rows_functor.cc | 177 ++++++++--- .../operators/math/selected_rows_functor.cu | 196 ++++++++++-- paddle/phi/kernels/CMakeLists.txt | 2 +- paddle/phi/kernels/activation_grad_kernel.h | 1 + paddle/phi/kernels/activation_kernel.h | 14 + .../phi/kernels/cpu/activation_grad_kernel.cc | 10 + paddle/phi/kernels/cpu/activation_kernel.cc | 37 +++ paddle/phi/kernels/funcs/activation_functor.h | 294 ++++++++++++++++++ .../phi/kernels/gpu/activation_grad_kernel.cu | 10 + paddle/phi/kernels/gpu/activation_kernel.cu | 35 +++ paddle/phi/kernels/gpu/clip_by_norm_kernel.cu | 12 +- paddle/phi/kernels/impl/activation_impl.h | 16 + 15 files changed, 740 insertions(+), 407 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 4205f2253a6..a9563db74d4 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -1650,9 +1650,6 @@ REGISTER_OPERATOR(logit, ops::LogitOp, ops::LogitOpMaker, ops::LogitGradOpMaker, ops::LogitGradOpMaker); REGISTER_OPERATOR(logit_grad, ops::LogitGradOp); -REGISTER_OP_CPU_KERNEL( - logit, ops::LogitKernel, - ops::LogitKernel); REGISTER_OP_CPU_KERNEL( logit_grad, ops::LogitGradKernel, ops::LogitGradKernel); @@ -1830,24 +1827,6 @@ REGISTER_OPERATOR( REGISTER_OPERATOR(exp_grad, ops::ActivationOpGrad, ops::ActivationGradOpInplaceInferer); -REGISTER_OP_CPU_KERNEL(exp, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); -REGISTER_OP_CPU_KERNEL( - exp_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); /* ========================================================================== */ /* ========================== expm1 register ============================ */ @@ -1862,13 +1841,6 @@ REGISTER_OPERATOR( REGISTER_OPERATOR(expm1_grad, ops::ActivationOpGrad, ops::ActivationGradOpInplaceInferer); -REGISTER_OP_CPU_KERNEL(expm1, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); REGISTER_OP_CPU_KERNEL( expm1_grad, ops::ActivationGradKernel>, diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index b076db01c22..20adf9136ae 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -273,6 +273,7 @@ USE_PHI_FUNCTOR(Asinh) USE_PHI_FUNCTOR(Acosh) USE_PHI_FUNCTOR(Atanh) USE_PHI_FUNCTOR(Tanh) +USE_PHI_FUNCTOR(Exp) USE_PHI_DOUBLE_GRAD_FUNCTOR(Tanh) USE_PHI_TRIPLE_GRAD_FUNCTOR(Tanh) USE_PHI_FUNCTOR(BRelu) @@ -455,37 +456,6 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -// exp(x) = e^x -template -struct ExpFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.exp(); - } -}; - -template -struct ExpGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * out; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - -// expm1(x) = e^x - 1 
-template -struct Expm1Functor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.expm1(); - } -}; - template struct Expm1GradFunctor : public BaseActivationFunctor { template { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -// sqrt(x) = x^(1/2) -template -struct SqrtFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.sqrt(); - } -}; - template struct SqrtGradFunctor : public BaseActivationFunctor { template { } }; -// rsqrt(x) = x^(-1/2) -template -struct RsqrtFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.rsqrt(); - } -}; - template struct RsqrtGradFunctor : public BaseActivationFunctor { template { } }; -// reciprocal(x) = 1 / x -template -struct ReciprocalFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = static_cast(1) / x; - } -}; - template struct ReciprocalGradFunctor : public BaseActivationFunctor { template { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -// square(x) = x^2 -template -struct SquareFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.square(); - } -}; - template struct SquareGradFunctor : public BaseActivationFunctor { template { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -// For numerical stability, using the following formula instead of softplus(x) = -// log(1 + exp(x)) -// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta = -// 1, threshold = 20 by default), otherwise x -template -struct SoftplusFunctor : public BaseActivationFunctor { - float beta; - float threshold; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}, {"threshold", &threshold}}; - } - - template - void operator()(Device d, X x, Out out) { - auto x_beta = static_cast(beta) * x; - out.device(d) = (x_beta > static_cast(threshold)) - .select(x, (static_cast(1) + x_beta.exp()).log() / - static_cast(beta)); - } -}; - // For numerical stability, using the following formula instead of // d(softplus(x))/dx = 1 / (1 + exp(-x)) // d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta @@ -939,24 +852,6 @@ struct SoftplusGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -// mish(x) = x * tanh(softplus(x)) -// softplus(x) = x, if x > threshold -// = ln(1 + exp(x)), otherwise -template -struct MishFunctor : public BaseActivationFunctor { - float threshold; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - template - void operator()(Device d, X x, Out out) { - auto sp = (x > static_cast(threshold)) - .select(x, (static_cast(1) + x.exp()).log()); - out.device(d) = x * sp.tanh(); - } -}; - // dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp))) // sp = softplus(x) template @@ -979,15 +874,6 @@ struct MishGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -// softsign(x) = x / (1 + |x|) -template -struct SoftsignFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) { - out.device(d) = x / (static_cast(1) + x.abs()); - } -}; - // 
d(softsign(x))/dx = 1 / (1 + |x|)^2 // Taken from https://en.wikipedia.org/wiki/Activation_function template @@ -1198,24 +1084,6 @@ struct PowGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct LogitFunctor { - template - void operator()(Device d, X x, Out out, P p, float eps) const { - // logit(x) = ln(x/(1-x)) - auto tmp_x = - (x.cwiseMin(static_cast(1.0 - eps))).cwiseMax(static_cast(eps)); - - if (!eps) { - out.device(d) = (x < static_cast(0.0) || x > static_cast(1.0)) - .select(p.constant(static_cast(NAN)), - (tmp_x / (static_cast(1) - tmp_x)).log()); - } else { - out.device(d) = (tmp_x / (static_cast(1) - tmp_x)).log(); - } - } -}; - template struct LogitGradFunctor { template @@ -1228,21 +1096,6 @@ struct LogitGradFunctor { } }; -template -struct STanhFunctor : public BaseActivationFunctor { - float scale_a; - float scale_b; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; - } - - template - void operator()(Device d, X x, Out out) const { - out.device(d) = - static_cast(scale_b) * (static_cast(scale_a) * x).tanh(); - } -}; - template struct STanhGradFunctor : public BaseActivationFunctor { float scale_a; @@ -2075,26 +1928,6 @@ class PowGradKernel } }; -template -class LogitKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* out = context.Output("Out"); - auto* in = context.Input("X"); - auto eps = context.Attr("eps"); - out->mutable_data(in->place()); - - auto eigen_out = framework::EigenVector::Flatten(*out); - auto eigen_in = framework::EigenVector::Flatten(*in); - auto& place = - *context.template device_context().eigen_device(); - auto eigen_p = framework::EigenVector::Flatten(*out); - - LogitFunctor functor; - functor(place, eigen_in, eigen_out, eigen_p, eps); - } -}; - template class LogitGradKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/activation_op.kps b/paddle/fluid/operators/activation_op.kps index 256f20db084..6e1ac642d11 100644 --- a/paddle/fluid/operators/activation_op.kps +++ b/paddle/fluid/operators/activation_op.kps @@ -192,14 +192,6 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor { } }; -template -struct CudaReciprocalFunctor : public BaseActivationFunctor { - T one = static_cast(1.0f); - - // reciprocal(x) = 1 / x - __device__ __forceinline__ T operator()(const T x) const { return one / x; } -}; - template struct CudaReciprocalGradFunctor : public BaseActivationFunctor { // dx = -dout * out^2 @@ -212,40 +204,6 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor { } }; -template -struct CudaExpFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // exp(x) = exp(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(exp(x)); - } -}; - -template -struct CudaExpGradFunctor : public BaseActivationFunctor { - // dx = dout * out - __device__ __forceinline__ T operator()(const T dout, const T out) const { - return dout * out; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - -template -struct CudaExpm1Functor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // expm1(x) = expm1(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - 
return static_cast(expm1(x)); - } -}; - template struct CudaExpm1GradFunctor : public BaseActivationFunctor { // dx = dout * out @@ -279,12 +237,6 @@ struct CudaLogGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct CudaSquareFunctor : public BaseActivationFunctor { - // square(x) = x * x - __device__ __forceinline__ T operator()(const T x) const { return x * x; } -}; - template struct CudaSquareGradFunctor : public BaseActivationFunctor { T two = static_cast(2.0f); @@ -297,17 +249,6 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct CudaSqrtFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // sqrt(x) = sqrt(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(sqrt(x)); - } -}; - template struct CudaSqrtGradFunctor : public BaseActivationFunctor { T one_half = static_cast(0.5f); @@ -322,17 +263,6 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor { } }; -template -struct CudaRsqrtFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // rsqrt(x) = rsqrt(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(rsqrt(x)); - } -}; - template struct CudaRsqrtGradFunctor : public BaseActivationFunctor { T minus_one_half = static_cast(-0.5f); @@ -466,25 +396,6 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor { } }; -template -struct CudaSTanhFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - float scale_a; - float scale_b; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; - } - - // stanh(x) = b * tanh(a * x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - MPType a = static_cast(scale_a); - MPType b = static_cast(scale_b); - return static_cast(b * tanh(a * x)); - } -}; - template struct CudaSTanhGradFunctor : public BaseActivationFunctor { using MPType = typename details::MPTypeTrait::Type; @@ -510,27 +421,6 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct CudaSoftplusFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - float beta; - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}, {"threshold", &threshold}}; - } - - // softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - MPType b = static_cast(beta); - MPType t = static_cast(threshold); - MPType x_beta = x * beta; - return static_cast(x_beta > t ? 
x : log(one + exp(x_beta)) / b); - } -}; - template struct CudaSoftplusGradFunctor : public BaseActivationFunctor { using MPType = typename details::MPTypeTrait::Type; @@ -556,16 +446,6 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct CudaSoftsignFunctor : public BaseActivationFunctor { - T one = static_cast(1.0f); - - // softsign(x) = x / (1 + abs(x)) - __device__ __forceinline__ T operator()(const T x) const { - return x / (one + abs(x)); - } -}; - template struct CudaSoftsignGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); @@ -762,27 +642,6 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct CudaMishFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - // mish(x) = x * tanh(softplus(x)) - // softplus(x) = x, if x > threshold - // = ln(1 + exp(x)), otherwise - // Inputs: args[0], the input x - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - MPType sp = (x > static_cast(threshold)) ? x : log(one + exp(x)); - return static_cast(x * tanh(sp)); - } -}; - template struct CudaMishGradFunctor : public BaseActivationFunctor { using MPType = typename details::MPTypeTrait::Type; @@ -1292,11 +1151,6 @@ REGISTER_OP_CUDA_KERNEL( /* ========================== logit register ============================ */ namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - logit, ops::LogitKernel, - ops::LogitKernel, - ops::LogitKernel); REGISTER_OP_CUDA_KERNEL( logit_grad, ops::LogitGradKernel, diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 5ac39953462..0ca2529f132 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -279,6 +279,46 @@ struct SelectedRowsAddToTensor { } }; +template +struct SelectedRowsAddToTensor { + void operator()(const phi::CPUContext& context, + const phi::SelectedRows& input1, framework::Tensor* input2) { + if (UNLIKELY(input1.rows().size() == 0)) { + LOG(WARNING) << "input selected rows is empty!"; + return; + } + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ( + in1_height, in2_dims[0], + platform::errors::InvalidArgument("The two inputs height must be equal." + "But recieved first input height = " + "[%d], second input height = [%d]", + in1_height, in2_dims[0])); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ( + in1_row_numel, input2->numel() / in1_height, + platform::errors::InvalidArgument( + "The two inputs width must be equal." 
+ "But recieved first input width = [%d], second input width = [%d]", + in1_row_numel, input2->numel() / in1_height)); + + auto* in1_data = in1_value.data(); + auto* input2_data = input2->data(); + + for (size_t i = 0; i < in1_rows.size(); i++) { + for (int64_t j = 0; j < in1_row_numel; j++) { + input2_data[in1_rows[i] * in1_row_numel + j] += + in1_data[i * in1_row_numel + j]; + } + } + } +}; + template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; @@ -286,6 +326,11 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; // This is a separated namespace for manipulate SelectedRows typed // data. Like merge duplicated rows, adding two SelectedRows etc. // @@ -294,30 +339,30 @@ template struct SelectedRowsAddToTensor +template typename std::enable_if::value>::type elementwise_add_to( - phi::funcs::BlasT* blas, size_t data_len, - const T* in, T* out) { + phi::funcs::BlasT* blas, size_t data_len, const T* in, + T* out) { blas->AXPY(data_len, T(1.f), in, out); } -template +template typename std::enable_if::value>::type elementwise_add_to( - phi::funcs::BlasT* blas, size_t data_len, - const T* in, T* out) { + phi::funcs::BlasT* blas, size_t data_len, const T* in, + T* out) { for (size_t i = 0; i < data_len; i++) { out[i] += in[i]; } } -template +template typename std::enable_if::value>::type add_sparse_inputs(const std::vector& inputs, const std::unordered_map& rows_to_id, - int64_t input_width, - const platform::CPUDeviceContext& context, T* out_data) { + int64_t input_width, const DeviceContext& context, + T* out_data) { #ifndef PADDLE_WITH_MKLDNN - auto blas = phi::funcs::GetBlas(context); + auto blas = phi::funcs::GetBlas(context); #endif for (auto* input : inputs) { if (input->rows().size() == 0) { @@ -336,22 +381,22 @@ add_sparse_inputs(const std::vector& inputs, #else for (size_t i = 0; i < input_rows.size(); i++) { size_t out_i = rows_to_id.at(input_rows[i]); - elementwise_add_to(&blas, static_cast(input_width), - &input_data[i * input_width], - &out_data[out_i * input_width]); + elementwise_add_to( + &blas, static_cast(input_width), &input_data[i * input_width], + &out_data[out_i * input_width]); } #endif } } -template +template typename std::enable_if::value>::type add_sparse_inputs(const std::vector& inputs, const std::unordered_map& rows_to_id, - int64_t input_width, - const platform::CPUDeviceContext& context, T* out_data) { + int64_t input_width, const DeviceContext& context, + T* out_data) { VLOG(4) << "[CPU] add_sparse_inputs <" << typeid(T).name(); - auto blas = phi::funcs::GetBlas(context); + auto blas = phi::funcs::GetBlas(context); for (auto* input : inputs) { if (input->rows().size() == 0) { continue; @@ -361,16 +406,16 @@ add_sparse_inputs(const std::vector& inputs, for (size_t i = 0; i < input_rows.size(); i++) { size_t out_i = rows_to_id.at(input_rows[i]); - elementwise_add_to(&blas, static_cast(input_width), - &input_data[i * input_width], - &out_data[out_i * input_width]); + elementwise_add_to( + &blas, static_cast(input_width), &input_data[i * input_width], + &out_data[out_i * input_width]); } } } -template -struct MergeAdd { - phi::SelectedRows operator()(const platform::CPUDeviceContext& context, +template +struct MergeAddImpl { + phi::SelectedRows operator()(const 
DeviceContext& context, const phi::SelectedRows& input, const bool sorted_result = false) { phi::SelectedRows out; @@ -378,15 +423,14 @@ struct MergeAdd { return out; } - void operator()(const platform::CPUDeviceContext& context, - const phi::SelectedRows& input, phi::SelectedRows* output, - const bool sorted_result = false) { + void operator()(const DeviceContext& context, const phi::SelectedRows& input, + phi::SelectedRows* output, const bool sorted_result = false) { std::vector inputs; inputs.push_back(&input); (*this)(context, inputs, output, sorted_result); } - void operator()(const platform::CPUDeviceContext& context, + void operator()(const DeviceContext& context, const std::vector& inputs, phi::SelectedRows* output, const bool sorted_result = false) { if (inputs.size() == 0) { @@ -461,7 +505,7 @@ struct MergeAdd { out.set_rows(merge_rows); - phi::funcs::SetConstant constant_functor; + phi::funcs::SetConstant constant_functor; constant_functor(context, out.mutable_value(), static_cast(0.f)); std::unordered_map rows_to_id; @@ -469,11 +513,75 @@ struct MergeAdd { rows_to_id[merge_rows[i]] = i; } - add_sparse_inputs(inputs, rows_to_id, input_width, context, out_data); + add_sparse_inputs(inputs, rows_to_id, input_width, + context, out_data); } } }; +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. + phi::SelectedRows operator()(const platform::CPUDeviceContext& context, + const phi::SelectedRows& input, + const bool sorted_result) { + return MergeAddImpl()(context, input, + sorted_result); + } + + void operator()(const platform::CPUDeviceContext& context, + const phi::SelectedRows& input, phi::SelectedRows* output, + const bool sorted_result) { + MergeAddImpl()(context, input, output, + sorted_result); + } + + void operator()(const platform::CPUDeviceContext& context, + const std::vector& inputs, + phi::SelectedRows* output, const bool sorted_result) { + MergeAddImpl()(context, inputs, output, + sorted_result); + } +}; + +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. 
+ phi::SelectedRows operator()(const phi::CPUContext& context, + const phi::SelectedRows& input, + const bool sorted_result) { + return MergeAddImpl()(context, input, sorted_result); + } + + void operator()(const phi::CPUContext& context, + const phi::SelectedRows& input, phi::SelectedRows* output, + const bool sorted_result) { + MergeAddImpl()(context, input, output, sorted_result); + } + + void operator()(const phi::CPUContext& context, + const std::vector& inputs, + phi::SelectedRows* output, const bool sorted_result) { + MergeAddImpl()(context, inputs, output, sorted_result); + } +}; + +#define TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(dtype) \ + template struct MergeAddImpl; \ + template struct MergeAddImpl; \ + template struct MergeAdd; \ + template struct MergeAdd; + +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(float) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(double) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int64_t) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::bfloat16) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::complex) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::complex) + #ifdef PADDLE_WITH_XPU template struct MergeAdd { @@ -714,17 +822,6 @@ struct MergeAverage { } }; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd>; -template struct MergeAdd>; -template struct MergeAdd; - #ifdef PADDLE_WITH_XPU template struct MergeAdd; #endif diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index a4678550cf7..542d4c97843 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -174,12 +174,77 @@ struct SelectedRowsAddTensor { } }; +template +struct SelectedRowsAddTensor { + void operator()(const phi::GPUContext& context, + const phi::SelectedRows& input1, + const framework::Tensor& input2, framework::Tensor* output) { + auto in1_height = input1.height(); + auto in2_dims = input2.dims(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ( + in1_height, in2_dims[0], + platform::errors::InvalidArgument( + "The two inputs height must be equal." + "But recieved first input height = [%d], first input height = [%d]", + in1_height, in2_dims[0])); + PADDLE_ENFORCE_EQ( + in1_height, out_dims[0], + platform::errors::InvalidArgument( + "The input and output height must be equal." + "But recieved input height = [%d], output height = [%d]", + in1_height, out_dims[0])); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ( + in1_row_numel, input2.numel() / in1_height, + platform::errors::InvalidArgument( + "The two inputs width must be equal." + "But recieved first input width = [%d], second input width = [%d]", + in1_row_numel, input2.numel() / in1_height)); + PADDLE_ENFORCE_EQ( + in1_row_numel, output->numel() / in1_height, + platform::errors::InvalidArgument( + "The input and output width must be equal." 
+ "But recieved input width = [%d], output width = [%d]", + in1_row_numel, output->numel() / in1_height)); + + auto* in1_data = in1_value.data(); + auto* in2_data = input2.data(); + auto* out_data = output->data(); + + phi::funcs::SetConstant functor; + functor(context, output, static_cast(0)); + + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(in1_rows.size(), 1); + paddle::framework::MixVector mixv_in1_rows(&in1_rows); + SelectedRowsAddTensorKernel< + T, block_size><<>>( + in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), out_data, + in1_row_numel); + + auto out_eigen = framework::EigenVector::Flatten(*output); + auto in2_eigen = framework::EigenVector::Flatten(input2); + out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen; + } +}; + template struct SelectedRowsAddTensor; template struct SelectedRowsAddTensor; template struct SelectedRowsAdd; template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; +template struct SelectedRowsAdd; +template struct SelectedRowsAddTensor; + template struct SelectedRowsAddTo { void operator()(const platform::CUDADeviceContext& context, @@ -285,12 +350,54 @@ struct SelectedRowsAddToTensor { } }; +template +struct SelectedRowsAddToTensor { + void operator()(const phi::GPUContext& context, + const phi::SelectedRows& input1, framework::Tensor* input2) { + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ( + in1_height, in2_dims[0], + platform::errors::InvalidArgument("The two inputs height must be equal." + "But recieved first input height = " + "[%d], second input height = [%d]", + in1_height, in2_dims[0])); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ( + in1_row_numel, input2->numel() / in1_height, + platform::errors::InvalidArgument( + "The two inputs width must be equal." 
+ "But recieved first input width = [%d], second input width = [%d]", + in1_row_numel, input2->numel() / in1_height)); + + auto* in1_data = in1_value.data(); + auto* in2_data = input2->data(); + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(in1_rows.size(), 1); + paddle::framework::MixVector mixv_in1_rows(&in1_rows); + SelectedRowsAddToTensorKernel< + T, block_size><<>>( + in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), in2_data, + in1_row_numel); + } +}; + template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; namespace scatter { @@ -319,9 +426,9 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows, } } -template -struct MergeAdd { - phi::SelectedRows operator()(const platform::CUDADeviceContext& context, +template +struct MergeAddImpl { + phi::SelectedRows operator()(const DeviceContext& context, const phi::SelectedRows& input, const bool sorted_result = false) { phi::SelectedRows out; @@ -329,9 +436,8 @@ struct MergeAdd { return out; } - void operator()(const platform::CUDADeviceContext& context, - const phi::SelectedRows& input, phi::SelectedRows* output, - const bool sorted_result = false) { + void operator()(const DeviceContext& context, const phi::SelectedRows& input, + phi::SelectedRows* output, const bool sorted_result = false) { framework::Vector input_rows(input.rows()); if (input_rows.size() == 0) { return; @@ -350,7 +456,7 @@ struct MergeAdd { phi::make_ddim({static_cast(merge_rows.size()), input_width}), context.GetPlace()); - phi::funcs::SetConstant constant_functor; + phi::funcs::SetConstant constant_functor; constant_functor(context, out.mutable_value(), static_cast(0)); auto* out_data = out.mutable_value()->data(); @@ -369,7 +475,7 @@ struct MergeAdd { mix_vector_out.CopyToCPU(); } - void operator()(const platform::CUDADeviceContext& context, + void operator()(const DeviceContext& context, const std::vector& inputs, phi::SelectedRows* output, const bool sorted_result = false) { if (inputs.size() == 0) { @@ -414,7 +520,7 @@ struct MergeAdd { phi::make_ddim({static_cast(merge_rows.size()), input_width}), context.GetPlace()); - phi::funcs::SetConstant constant_functor; + phi::funcs::SetConstant constant_functor; constant_functor(context, out.mutable_value(), static_cast(0)); auto* out_data = out.mutable_value()->data(); @@ -441,15 +547,69 @@ struct MergeAdd { } }; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd>; -template struct MergeAdd>; +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. 
+ phi::SelectedRows operator()(const platform::CUDADeviceContext& context, + const phi::SelectedRows& input, + const bool sorted_result) { + return MergeAddImpl()(context, input, + sorted_result); + } + + void operator()(const platform::CUDADeviceContext& context, + const phi::SelectedRows& input, phi::SelectedRows* output, + const bool sorted_result) { + MergeAddImpl()(context, input, output, + sorted_result); + } + + void operator()(const platform::CUDADeviceContext& context, + const std::vector& inputs, + phi::SelectedRows* output, const bool sorted_result) { + MergeAddImpl()(context, inputs, output, + sorted_result); + } +}; + +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. + phi::SelectedRows operator()(const phi::GPUContext& context, + const phi::SelectedRows& input, + const bool sorted_result) { + return MergeAddImpl()(context, input, sorted_result); + } + + void operator()(const phi::GPUContext& context, + const phi::SelectedRows& input, phi::SelectedRows* output, + const bool sorted_result) { + MergeAddImpl()(context, input, output, sorted_result); + } + + void operator()(const phi::GPUContext& context, + const std::vector& inputs, + phi::SelectedRows* output, const bool sorted_result) { + MergeAddImpl()(context, inputs, output, sorted_result); + } +}; + +#define TEMPLATE_SPECIALIZED_FOR_MERGEADD(dtype) \ + template struct MergeAddImpl; \ + template struct MergeAddImpl; \ + template struct MergeAdd; \ + template struct MergeAdd; + +TEMPLATE_SPECIALIZED_FOR_MERGEADD(float) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(double) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(int) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(int64_t) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::float16) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::bfloat16) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex) template __global__ void UpdateToTensorKernel(const T* selected_rows, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index d443b7bb2a0..761d65ed36c 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -11,7 +11,7 @@ set_property(GLOBAL PROPERTY PHI_KERNELS "") # [ 1. 
Common kernel compilation dependencies ] set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel) -set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor) # remove this dep after removing fluid deps on tensor creation set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index a5b737b28c2..afcb87f9b83 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -100,5 +100,6 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh); DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu); DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Tanh); +DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Exp); } // namespace phi diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h index 885dccad8e3..623f7e467ad 100644 --- a/paddle/phi/kernels/activation_kernel.h +++ b/paddle/phi/kernels/activation_kernel.h @@ -37,6 +37,8 @@ DECLARE_ACTIVATION_KERNEL(Acosh) DECLARE_ACTIVATION_KERNEL(Atanh) DECLARE_ACTIVATION_KERNEL(Relu) DECLARE_ACTIVATION_KERNEL(Tanh) +DECLARE_ACTIVATION_KERNEL(Exp) +DECLARE_ACTIVATION_KERNEL(Expm1) template void BReluKernel(const Context& dev_ctx, @@ -57,4 +59,16 @@ void ThresholdedReluKernel(const Context& dev_ctx, float threshold, DenseTensor* out); +template +void LogitKernel(const Context& dev_ctx, + const DenseTensor& x, + float eps, + DenseTensor* out); + +template +void MishKernel(const Context& dev_ctx, + const DenseTensor& x, + float threshold, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index f9af50f6832..d454700076d 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -104,6 +104,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Tanh, funcs::TanhGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Exp, funcs::ExpGradFunctor); DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu, funcs::LeakyReluGradFunctor, @@ -159,3 +160,12 @@ PD_REGISTER_KERNEL(tanh_triple_grad, float, double, phi::dtype::float16) {} + +PD_REGISTER_KERNEL(exp_grad, + CPU, + ALL_LAYOUT, + phi::ExpGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index 0d13429c8f6..ecbab531232 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/activation_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/activation_functor.h" #include "paddle/phi/kernels/impl/activation_impl.h" namespace phi { @@ -67,11 +68,27 @@ DEFINE_CPU_ACTIVATION_KERNEL(Acosh, funcs::AcoshFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Atanh, funcs::AtanhFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Relu, funcs::ReluCPUFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Tanh, funcs::TanhFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Exp, funcs::ExpFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Expm1, funcs::Expm1Functor) +DEFINE_CPU_ACTIVATION_KERNEL(Reciprocal, funcs::ReciprocalFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Square, funcs::SquareFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Sqrt, funcs::SqrtFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Rsqrt, funcs::RsqrtFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Softsign, funcs::SoftsignFunctor) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, funcs::LeakyReluFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, funcs::ThresholdedReluFunctor, threshold) +DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, funcs::MishFunctor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, funcs::BReluFunctor, t_min, t_max) +DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh, + funcs::STanhFunctor, + scale_a, + scale_b) +DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, + funcs::SoftplusFunctor, + beta, + threshold) } // namespace phi PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {} @@ -94,3 +111,23 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, Tanh) PD_REGISTER_ACTIVATION_KERNEL(brelu, BRelu) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyRelu) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedRelu) +PD_REGISTER_ACTIVATION_KERNEL(mish, Mish) +PD_REGISTER_ACTIVATION_KERNEL(stanh, STanh) +PD_REGISTER_ACTIVATION_KERNEL(reciprocal, Reciprocal) +PD_REGISTER_ACTIVATION_KERNEL(sqrt, Sqrt) +PD_REGISTER_ACTIVATION_KERNEL(rsqrt, Rsqrt) +PD_REGISTER_ACTIVATION_KERNEL(softplus, Softplus) +PD_REGISTER_ACTIVATION_KERNEL(softsign, Softsign) + +PD_REGISTER_KERNEL( + exp, CPU, ALL_LAYOUT, phi::ExpKernel, float, double, int, int64_t) {} +PD_REGISTER_KERNEL(expm1, + CPU, + ALL_LAYOUT, + phi::Expm1Kernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {} +PD_REGISTER_KERNEL( + square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index c8fb54bb102..5b2e70ceb54 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -100,6 +100,15 @@ struct SinFunctor : public BaseActivationFunctor { } }; +// reciprocal(x) = 1 / x +template +struct ReciprocalFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = static_cast(1) / x; + } +}; + // cosine'(x) = -sin(x) template struct CosGradFunctor : public BaseActivationFunctor { @@ -124,6 +133,57 @@ struct CosFunctor : public BaseActivationFunctor { } }; +template +struct LogitFunctor { + template + void operator()(Device d, X x, Out out, P p, float eps) const { + // logit(x) = ln(x/(1-x)) + auto tmp_x = + (x.cwiseMin(static_cast(1.0 - eps))).cwiseMax(static_cast(eps)); + + if (!eps) { + out.device(d) = (x < static_cast(0.0) || x > static_cast(1.0)) + .select(p.constant(static_cast(NAN)), + (tmp_x / 
(static_cast(1) - tmp_x)).log()); + } else { + out.device(d) = (tmp_x / (static_cast(1) - tmp_x)).log(); + } + } +}; + +// mish(x) = x * tanh(softplus(x)) +// softplus(x) = x, if x > threshold +// = ln(1 + exp(x)), otherwise +template +struct MishFunctor : public BaseActivationFunctor { + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Out out) { + auto sp = (x > static_cast(threshold)) + .select(x, (static_cast(1) + x.exp()).log()); + out.device(d) = x * sp.tanh(); + } +}; + +template +struct STanhFunctor : public BaseActivationFunctor { + float scale_a; + float scale_b; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; + } + + template + void operator()(Device d, X x, Out out) const { + out.device(d) = + static_cast(scale_b) * (static_cast(scale_a) * x).tanh(); + } +}; + template struct Tangent { HOSTDEVICE T operator()(const T& val) const { return tan(val); } @@ -151,6 +211,55 @@ struct TanGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +// square(x) = x^2 +template +struct SquareFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.square(); + } +}; + +// sqrt(x) = x^(1/2) +template +struct SqrtFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.sqrt(); + } +}; + +// rsqrt(x) = x^(-1/2) +template +struct RsqrtFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.rsqrt(); + } +}; + +// For numerical stability, using the following formula instead of softplus(x) = +// log(1 + exp(x)) +// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta = +// 1, threshold = 20 by default), otherwise x +template +struct SoftplusFunctor : public BaseActivationFunctor { + float beta; + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}, {"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Out out) { + auto x_beta = static_cast(beta) * x; + out.device(d) = (x_beta > static_cast(threshold)) + .select(x, + (static_cast(1) + x_beta.exp()).log() / + static_cast(beta)); + } +}; + // Tangent(x) = tan(x) template struct TanFunctor : public BaseActivationFunctor { @@ -452,6 +561,41 @@ struct AtanhGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +// exp functor +// exp(x) = e^x +template +struct ExpFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.exp(); + } +}; + +template +struct ExpGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + +// expm1(x) = e^x - 1 +template +struct Expm1Functor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.expm1(); + } +}; + // relu(x) = max(x, 0) template struct ReluCPUFunctor : public BaseActivationFunctor { @@ -672,6 +816,15 @@ struct BReluGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; 
+// softsign(x) = x / (1 + |x|) +template +struct SoftsignFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) { + out.device(d) = x / (static_cast(1) + x.abs()); + } +}; + template struct LeakyReluFunctor : public BaseActivationFunctor { float alpha; @@ -827,6 +980,54 @@ struct CudaCosGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaExpFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // exp(x) = exp(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(exp(x)); + } +}; + +template +struct CudaSquareFunctor : public BaseActivationFunctor { + // square(x) = x * x + __device__ __forceinline__ T operator()(const T x) const { return x * x; } +}; + +template +struct CudaExpGradFunctor : public BaseActivationFunctor { + // dx = dout * out + __device__ __forceinline__ T operator()(const T dout, const T out) const { + return dout * out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + +template +struct CudaReciprocalFunctor : public BaseActivationFunctor { + T one = static_cast(1.0f); + + // reciprocal(x) = 1 / x + __device__ __forceinline__ T operator()(const T x) const { return one / x; } +}; + +template +struct CudaExpm1Functor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // expm1(x) = expm1(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(expm1(x)); + } +}; + template struct CudaSinFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -838,6 +1039,16 @@ struct CudaSinFunctor : public BaseActivationFunctor { } }; +template +struct CudaSoftsignFunctor : public BaseActivationFunctor { + T one = static_cast(1.0f); + + // softsign(x) = x / (1 + abs(x)) + __device__ __forceinline__ T operator()(const T x) const { + return x / (one + abs(x)); + } +}; + template struct CudaSinGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -1049,6 +1260,46 @@ struct CudaAtanhFunctor : public BaseActivationFunctor { } }; +template +struct CudaSTanhFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + float scale_a; + float scale_b; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; + } + + // stanh(x) = b * tanh(a * x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + MPType a = static_cast(scale_a); + MPType b = static_cast(scale_b); + return static_cast(b * tanh(a * x)); + } +}; + +template +struct CudaSoftplusFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float beta; + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}, {"threshold", &threshold}}; + } + + // softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + MPType b = static_cast(beta); + MPType t = static_cast(threshold); + MPType x_beta = x * beta; + return static_cast(x_beta > t ? 
x : log(one + exp(x_beta)) / b); + } +}; + template struct CudaAtanhGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -1064,6 +1315,28 @@ struct CudaAtanhGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaSqrtFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // sqrt(x) = sqrt(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(sqrt(x)); + } +}; + +template +struct CudaRsqrtFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // rsqrt(x) = rsqrt(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(rsqrt(x)); + } +}; + template struct CudaAtanFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -1131,6 +1404,27 @@ struct CudaBReluFunctor : public BaseActivationFunctor { } }; +template +struct CudaMishFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + // mish(x) = x * tanh(softplus(x)) + // softplus(x) = x, if x > threshold + // = ln(1 + exp(x)), otherwise + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + MPType sp = (x > static_cast(threshold)) ? x : log(one + exp(x)); + return static_cast(x * tanh(sp)); + } +}; + template struct CudaBReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 00792b8ab60..40cea77722d 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -155,6 +155,7 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, CudaCoshGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, CudaAsinhGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, CudaAcoshGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, CudaAtanhGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Exp, CudaExpGradFunctor); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu, CudaLeakyReluGradFunctor, @@ -234,3 +235,12 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad, LeakyReluDoubleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad, ThresholdedReluGradKernel) + +PD_REGISTER_KERNEL(exp_grad, + GPU, + ALL_LAYOUT, + phi::ExpGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index 3c340a89f57..a93b4daa95d 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -20,6 +20,7 @@ limitations under the License. 
*/ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/impl/activation_grad_impl.h" +#include "paddle/phi/kernels/impl/activation_impl.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" @@ -88,13 +89,27 @@ DEFINE_GPU_ACTIVATION_KERNEL(Acosh, funcs::CudaAcoshFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Atanh, funcs::CudaAtanhFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Relu, funcs::CudaReluFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Tanh, funcs::CudaTanhFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Exp, funcs::CudaExpFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Expm1, funcs::CudaExpm1Functor) +DEFINE_GPU_ACTIVATION_KERNEL(Reciprocal, funcs::CudaReciprocalFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Square, funcs::CudaSquareFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Sqrt, funcs::CudaSqrtFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Rsqrt, funcs::CudaRsqrtFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Softsign, funcs::CudaSoftsignFunctor) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, CudaThresholdedReluFunctor, threshold) +DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, CudaMishFunctor, threshold) + DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max) +DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Stanh, CudaSTanhFunctor, scale_a, scale_b) +DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, + CudaSoftplusFunctor, + beta, + threshold) } // namespace phi @@ -142,3 +157,23 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) +PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) +PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel) +PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel) +PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) +PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel) +PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) +PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel) + +PD_REGISTER_KERNEL( + exp, GPU, ALL_LAYOUT, phi::ExpKernel, float, double, int, int64_t) {} +PD_REGISTER_KERNEL(expm1, + GPU, + ALL_LAYOUT, + phi::Expm1Kernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(logit, GPU, ALL_LAYOUT, phi::LogitKernel, float, double) {} +PD_REGISTER_KERNEL( + square, GPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/clip_by_norm_kernel.cu b/paddle/phi/kernels/gpu/clip_by_norm_kernel.cu index 74efdcf7334..d9f3e247ef9 100644 --- a/paddle/phi/kernels/gpu/clip_by_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/clip_by_norm_kernel.cu @@ -40,16 +40,16 @@ void ClipByNormKernel( DenseTensor tmp; tmp.Resize({1}); dev_ctx.template Alloc(&tmp); - kernels::TensorReduceImpl>( + + phi::funcs::ReduceKernel>( dev_ctx, x_in, &tmp, kps::SquareFunctor(), - reduce_dims, - dev_ctx.stream()); + reduce_dims); auto tmp_eigen = EigenVector::Flatten(tmp); auto x_norm = tmp_eigen.sqrt(); diff --git a/paddle/phi/kernels/impl/activation_impl.h b/paddle/phi/kernels/impl/activation_impl.h index ca3debd394a..05339ceb748 100644 --- a/paddle/phi/kernels/impl/activation_impl.h +++ b/paddle/phi/kernels/impl/activation_impl.h @@ -47,4 +47,20 @@ void ActivationImpl(const Context& dev_ctx, } } +template +void LogitKernel(const Context& dev_ctx, + const DenseTensor& x, + float eps, + DenseTensor* out) { + dev_ctx.template Alloc(out); + + auto 
eigen_out = EigenVector::Flatten(*out); + auto eigen_in = EigenVector::Flatten(x); + auto& place = *dev_ctx.eigen_device(); + auto eigen_p = EigenVector::Flatten(*out); + + funcs::LogitFunctor functor; + functor(place, eigen_in, eigen_out, eigen_p, eps); +} + } // namespace phi -- GitLab
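
A minimal sketch of the functor-plus-generic-driver shape this patch standardizes on: the forward activation functors (Exp, Expm1, Sqrt, Rsqrt, Reciprocal, Square, Softplus, Softsign, Mish, STanh, Logit) move out of fluid's activation_op.h / activation_op.kps into phi::funcs and are applied through a shared driver and registered per op. The sketch uses plain std::vector in place of DenseTensor/Eigen; ToyExpFunctor, ToyExpm1Functor and ApplyActivation are illustrative names only, not PaddlePaddle APIs and not part of this patch.

// Illustrative sketch, not part of the patch: mirrors the phi activation
// pattern (stateless element-wise functor + one generic driver per device)
// with std::vector standing in for DenseTensor/Eigen.
#include <cmath>
#include <cstdio>
#include <vector>

// Analogue of a phi::funcs forward functor such as ExpFunctor.
struct ToyExpFunctor {
  double operator()(double x) const { return std::exp(x); }
};

// Analogue of Expm1Functor: expm1(x) = e^x - 1 computed without cancellation.
struct ToyExpm1Functor {
  double operator()(double x) const { return std::expm1(x); }
};

// Analogue of the shared ActivationImpl driver: one template applies any
// element-wise functor, so each op only supplies its functor.
template <typename Functor>
std::vector<double> ApplyActivation(const std::vector<double>& x,
                                    Functor functor) {
  std::vector<double> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = functor(x[i]);
  return out;
}

int main() {
  std::vector<double> x = {0.0, 1.0, 2.0};
  auto y = ApplyActivation(x, ToyExpFunctor{});    // exp(x)
  auto z = ApplyActivation(x, ToyExpm1Functor{});  // exp(x) - 1
  for (size_t i = 0; i < x.size(); ++i)
    std::printf("exp(%.1f)=%.4f  expm1(%.1f)=%.4f\n", x[i], y[i], x[i], z[i]);
  return 0;
}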
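
The selected_rows_functor.cc/.cu changes route both the old platform::CPUDeviceContext / platform::CUDADeviceContext and the new phi::CPUContext / phi::GPUContext through a single MergeAddImpl<DeviceContext, T>, so both context types share one implementation and only thin MergeAdd wrappers differ. The sketch below shows only the row-merging semantics MergeAdd provides (rows with the same index are summed into one row); ToySelectedRows and ToyMergeAdd are hypothetical stand-ins, not the real phi::SelectedRows or scatter::MergeAdd.

// Illustrative sketch, not part of the patch: a toy version of what
// scatter::MergeAdd does to a SelectedRows -- duplicated row indices are
// accumulated into a single output row.
#include <cstdio>
#include <map>
#include <vector>

struct ToySelectedRows {
  std::vector<int64_t> rows;   // row index of each stored row
  std::vector<double> values;  // row-major, `width` columns per stored row
  int64_t width;
};

ToySelectedRows ToyMergeAdd(const ToySelectedRows& in) {
  std::map<int64_t, std::vector<double>> merged;
  for (size_t i = 0; i < in.rows.size(); ++i) {
    auto& acc = merged[in.rows[i]];
    acc.resize(in.width, 0.0);
    for (int64_t j = 0; j < in.width; ++j)
      acc[j] += in.values[i * in.width + j];
  }
  ToySelectedRows out{{}, {}, in.width};
  for (const auto& kv : merged) {
    out.rows.push_back(kv.first);
    out.values.insert(out.values.end(), kv.second.begin(), kv.second.end());
  }
  return out;
}

int main() {
  // Rows {2, 0, 2} with width 2: the two rows tagged "2" are summed.
  ToySelectedRows in{{2, 0, 2}, {1, 1, 5, 5, 2, 3}, 2};
  ToySelectedRows out = ToyMergeAdd(in);  // row 0 -> [5, 5], row 2 -> [3, 4]
  for (size_t i = 0; i < out.rows.size(); ++i)
    std::printf("row %lld: [%g, %g]\n",
                static_cast<long long>(out.rows[i]),
                out.values[i * 2], out.values[i * 2 + 1]);
  return 0;
}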
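
funcs::LogitFunctor (moved into activation_functor.h by this patch) clamps the input to [eps, 1 - eps] before taking ln(x / (1 - x)); with eps == 0 it maps inputs outside [0, 1] to NaN instead. A scalar sketch of that behavior, assuming double precision is enough for illustration; toy_logit is not a Paddle symbol.

// Illustrative sketch, not part of the patch: scalar version of the eps
// handling in funcs::LogitFunctor.
#include <algorithm>
#include <cmath>
#include <cstdio>

double toy_logit(double x, float eps) {
  // With eps == 0, out-of-range inputs are NaN, as in the Eigen functor.
  if (eps == 0.0f && (x < 0.0 || x > 1.0)) return NAN;
  double lo = eps, hi = 1.0 - eps;
  double clamped = std::min(std::max(x, lo), hi);  // x clamped to [eps, 1-eps]
  return std::log(clamped / (1.0 - clamped));
}

int main() {
  std::printf("%f %f %f\n",
              toy_logit(0.5, 0.0f),    // 0
              toy_logit(1.5, 0.0f),    // nan
              toy_logit(1.5, 1e-6f));  // ~13.8 (clamped to 1 - 1e-6)
  return 0;
}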