/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/cuda_device_function.h"

namespace paddle {
namespace operators {

template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // relu(x) = max(x, 0)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] > zero ? args[0] : zero;
  }
};

template <typename T>
struct CudaReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // dx = dout * (out > 0)
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[1] > zero ? args[0] : zero;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // leakyrelu(x) = x > 0 ? x : alpha * x
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] > zero ? args[0] : static_cast<T>(alpha) * args[0];
  }
};

template <typename T>
struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // dx = dout * (x > 0 ? 1 : alpha)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[1] > zero ? args[0] : static_cast<T>(alpha) * args[0];
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
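// Calling convention shared by all functors in this file: the elementwise
// launcher hands each functor a small per-thread argument array. Forward
// functors read args[0] as the input x; backward functors read args[0] as the
// upstream gradient dout and args[1] as either the forward input x or the
// forward output out, as declared by FwdDeps() (kDepX vs. kDepOut).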
template <typename T>
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // sigmoid(x) = 1 / (1 + exp(-x))
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(one / (one + exp(-x)));
  }
};

template <typename T>
struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout * out * (1 - out)
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] * args[1] * (one - args[1]);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  // MPType means Compute Type
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // silu(x) = x / (1 + exp(-x))
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(x / (one + exp(-x)));
  }
};

template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * sigmoid(x) * (1 + x * (1 - sigmoid(x)))
  //    = dout * (1 / (1 + exp(-x)) + x * exp(-x) / (1 + exp(-x))^2)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    MPType temp = one / (one + exp(-x));
    return static_cast<T>(dout * (temp * (one + x * (one - temp))));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // logsigmoid(x) = log(1 / (1 + exp(-x)))
  // For numerical stability,
  // logsigmoid(x) =
  //   - (max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    MPType temp = x > zero ? zero : -x;
    return static_cast<T>(-temp - log(exp(-temp) + exp(-x - temp)));
  }
};

template <typename T>
struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // dx = dout * exp(-x) / (1 + exp(-x))
  // For numerical stability:
  // dx = dout * exp(-x - max(-x, 0))
  //      / (exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    MPType temp1 = x > zero ? zero : -x;
    MPType temp2 = exp(-x - temp1);
    return static_cast<T>(dout * (temp2 / (exp(-temp1) + temp2)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
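// The max(-x, 0) shift used by the two logsigmoid functors above is the
// standard log-sum-exp trick: for large negative x, exp(-x) would overflow in
// the compute type, so every exponent is rewritten as exp(-x - m) with
// m = max(-x, 0). All exponents then stay <= 0 while the result is
// mathematically unchanged.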
template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // atan(x) = atan(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(atan(x));
  }
};

template <typename T>
struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + x^2)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] / (one + args[1] * args[1]);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // softshrink(x) = x - lambda, if x > lambda;
  //                 x + lambda, if x < -lambda;
  //                 0, otherwise.
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    T x = args[0];
    T l = static_cast<T>(lambda);
    T temp1 = static_cast<T>(x > l);
    T temp2 = static_cast<T>(x < -l);
    return temp1 * (x - l) + temp2 * (x + l);
  }
};

template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // dx = dout, if x > lambda or x < -lambda, else 0
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    T x = args[1];
    T l = static_cast<T>(lambda);
    return (x >= -l && x <= l) ? zero : args[0];
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // ceil(x) = ceil(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(ceil(x));
  }
};

template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // floor(x) = floor(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(floor(x));
  }
};

template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // round(x) = round(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(round(x));
  }
};

// grad functor for ceil, floor and round
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
  __device__ __forceinline__ T operator()(const T* args) const {
    return static_cast<T>(0.0f);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
};

template <typename T>
struct CudaCosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // cos(x) = cos(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(cos(x));
  }
};

template <typename T>
struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout * (-sin(x))
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    return static_cast<T>(-dout * sin(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // sin(x) = sin(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(sin(x));
  }
};

template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout * cos(x)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    return static_cast<T>(dout * cos(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // tan(x) = tan(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(tan(x));
  }
};

template <typename T>
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout / cos(x)^2
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    return static_cast<T>(dout / (cos(x) * cos(x)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // asin(x) = asin(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(asin(x));
  }
};

template <typename T>
struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout / sqrt(1 - x^2)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    return static_cast<T>(dout / sqrt(one - x * x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaAcosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // acos(x) = acos(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(acos(x));
  }
};

template <typename T>
struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = -dout / sqrt(1 - x^2)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    return static_cast<T>(-dout / sqrt(one - x * x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaCoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // cosh(x) = cosh(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(cosh(x));
  }
};

template <typename T>
struct CudaCoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout * sinh(x)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    return static_cast<T>(dout * sinh(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // sinh(x) = sinh(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(sinh(x));
  }
};

template <typename T>
struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout * cosh(x)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    return static_cast<T>(dout * cosh(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // tanh(x) = tanh(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(tanh(x));
  }
};

template <typename T>
struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout * (1 - out^2)
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    T dout = static_cast<T>(args[0]);
    T out = static_cast<T>(args[1]);
    return dout * (one - out * out);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // reciprocal(x) = 1 / x
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    return one / args[0];
  }
};

template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
  // dx = -dout * out^2
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    return -args[0] * args[1] * args[1];
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // exp(x) = exp(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(exp(x));
  }
};

template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * out
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] * args[1];
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // log(x) = log(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(log(x));
  }
};

template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout / x
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] / args[1];
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
  // square(x) = x * x
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] * args[0];
  }
};

template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
  T two = static_cast<T>(2.0f);

  // dx = dout * 2 * x
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] * two * args[1];
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // sqrt(x) = sqrt(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(sqrt(x));
  }
};

template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
  T one_half = static_cast<T>(0.5f);

  // dx = dout * 0.5 / out
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    return one_half * args[0] / args[1];
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // rsqrt(x) = rsqrt(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(rsqrt(x));
  }
};

template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
  T minus_one_half = static_cast<T>(-0.5f);

  // dx = -0.5 * dout * out^3
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    T out = args[1];
    return minus_one_half * args[0] * out * out * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    framework::Tensor* out = nullptr;
    ExtractActivationTensor(ctx, &x, &out);
    out->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    std::vector<const framework::Tensor*> ins = {x};
    std::vector<framework::Tensor*> outs = {out};
    auto functor = Functor();
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = ctx.Attr<float>(attr.first);
    }
    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(dev_ctx, ins,
                                                               &outs, functor);
  }
};
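// The grad kernel below appends a second input tensor only when the functor
// declares a forward dependency: the forward output Out for kDepOut (e.g.
// CudaReluGradFunctor), the forward input X for kDepX (e.g.
// CudaLeakyReluGradFunctor), and nothing for kNoDeps (CudaZeroGradFunctor),
// in which case only dout is handed to the elementwise launcher.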
template <typename DeviceContext, typename Functor>
class ActivationGradCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor *x, *out, *d_out;
    framework::Tensor* d_x = nullptr;
    x = out = d_out = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(ctx, &x, &out, &d_out,
                                                    &d_x);
    d_x->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto functor = Functor();
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = ctx.Attr<float>(attr.first);
    }
    std::vector<const framework::Tensor*> ins = {d_out};
    std::vector<framework::Tensor*> outs = {d_x};

    if (static_cast<int>(Functor::FwdDeps()) == static_cast<int>(kDepOut)) {
      // Only need forward output Out
      ins.push_back(out);
      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
          dev_ctx, ins, &outs, functor);
    } else if (static_cast<int>(Functor::FwdDeps()) ==
               static_cast<int>(kDepX)) {
      // Only need forward input X
      ins.push_back(x);
      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
          dev_ctx, ins, &outs, functor);
    } else {
      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
          dev_ctx, ins, &outs, functor);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor,          \
                                       grad_functor)                        \
  REGISTER_OP_CUDA_KERNEL(                                                  \
      act_type,                                                             \
      ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>,  \
      ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<double>>, \
      ops::ActivationKernel<plat::CUDADeviceContext,                        \
                            ops::functor<plat::float16>>);                  \
  REGISTER_OP_CUDA_KERNEL(                                                  \
      act_type##_grad,                                                      \
      ops::ActivationGradKernel<plat::CUDADeviceContext,                    \
                                ops::grad_functor<float>>,                  \
      ops::ActivationGradKernel<plat::CUDADeviceContext,                    \
                                ops::grad_functor<double>>,                 \
      ops::ActivationGradKernel<plat::CUDADeviceContext,                    \
                                ops::grad_functor<plat::float16>>);

#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,     \
                                        grad_functor)                   \
  REGISTER_OP_CUDA_KERNEL(                                               \
      act_type,                                                          \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<float>>,                    \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<double>>,                   \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<plat::float16>>);           \
  REGISTER_OP_CUDA_KERNEL(                                               \
      act_type##_grad,                                                   \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<float>>,           \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<double>>,          \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<plat::float16>>);

/* ======================== leaky relu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor,
                                CudaLeakyReluGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext,
        ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ======================== elu register ============================ */
REGISTER_ACTIVATION_GPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    elu_grad_grad,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* =========================== relu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
                                CudaReluGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* =========================== tanh register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(tanh, Tanh, CudaTanhFunctor,
                                CudaTanhGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    tanh_grad_grad,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<float>>,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<double>>,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* =========================== sqrt register ============================= */
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, CudaSqrtFunctor,
                                CudaSqrtGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    sqrt_grad_grad,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* =========================== rsqrt register ============================= */
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, CudaRsqrtFunctor,
                                CudaRsqrtGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    rsqrt_grad_grad,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<float>>,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<double>>,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
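// Unlike the macro-based registrations above, the square, pow and exp ops
// below are spelled out explicitly so that, in addition to float, double and
// plat::float16, they can also register int and int64_t kernels.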
/* =========================== square register ============================ */
REGISTER_OP_CUDA_KERNEL(
    square,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaSquareFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaSquareFunctor<double>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaSquareFunctor<int>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaSquareFunctor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaSquareFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    square_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaSquareGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaSquareGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaSquareGradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaSquareGradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaSquareGradFunctor<plat::float16>>);

REGISTER_OP_CUDA_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */

/* ========================== pow register ============================ */
REGISTER_OP_CUDA_KERNEL(
    pow, ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int64_t>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<double>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int64_t>>,
    ops::PowGradKernel<plat::CUDADeviceContext,
                       ops::PowGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ========================== exp register ============================ */
REGISTER_OP_CUDA_KERNEL(
    exp, ops::ActivationCudaKernel<plat::CUDADeviceContext,
                                   ops::CudaExpFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<double>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    exp_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ========================== Log register ==================================*/
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    log_grad_grad,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<float>>,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<double>>,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */

REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor,
                                CudaSigmoidGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(silu, Silu, CudaSiluFunctor,
                                CudaSiluGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor,
                                CudaLogSigmoidGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(atan, Atan, CudaAtanFunctor,
                                CudaAtanGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(softshrink, SoftShrink, CudaSoftShrinkFunctor,
                                CudaSoftShrinkGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(ceil, Ceil, CudaCeilFunctor,
                                CudaZeroGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(floor, Floor, CudaFloorFunctor,
                                CudaZeroGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(cos, Cos, CudaCosFunctor, CudaCosGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(tan, Tan, CudaTanFunctor, CudaTanGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(acos, Acos, CudaAcosFunctor,
                                CudaAcosGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(sin, Sin, CudaSinFunctor, CudaSinGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(asin, Asin, CudaAsinFunctor,
                                CudaAsinGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(sinh, Sinh, CudaSinhFunctor,
                                CudaSinhGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(cosh, Cosh, CudaCoshFunctor,
                                CudaCoshGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(round, Round, CudaRoundFunctor,
                                CudaZeroGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, CudaReciprocalFunctor,
                                CudaReciprocalGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(log2, Log2, Log2Functor, Log2GradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(log10, Log10, Log10Functor, Log10GradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(brelu, BRelu, BReluFunctor, BReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(soft_relu, SoftRelu, SoftReluFunctor,
                               SoftReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(stanh, STanh, STanhFunctor, STanhGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(softplus, Softplus, SoftplusFunctor,
                               SoftplusGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(softsign, Softsign, SoftsignFunctor,
                               SoftsignGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(tanh_shrink, TanhShrink, TanhShrinkFunctor,
                               TanhShrinkGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_shrink, HardShrink, HardShrinkFunctor,
                               HardShrinkGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,
                               HardSigmoidGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(swish, Swish, SwishFunctor, SwishGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(thresholded_relu, ThresholdedRelu,
                               ThresholdedReluFunctor,
                               ThresholdedReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_swish, HardSwish, HardSwishFunctor,
                               HardSwishGradFunctor);
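// A minimal sketch (illustration only, not compiled into this file) of how a
// unary forward functor such as CudaReluFunctor<float> could be exercised by
// hand in a plain CUDA kernel; inside the operators above this role is played
// by LaunchElementwiseCudaKernel:
//
//   template <typename T, typename Functor>
//   __global__ void ApplyUnary(const T* x, T* y, int n, Functor functor) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       T arg = x[i];
//       y[i] = functor(&arg);  // forward functors read args[0] = x
//     }
//   }
//
//   // Usage: ApplyUnary<<<(n + 255) / 256, 256>>>(d_x, d_y, n,
//   //                                             CudaReluFunctor<float>());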