From eca8dcc7a3d95d970d960a0e6f1631ca448324c1 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Tue, 27 Apr 2021 15:01:07 +0800 Subject: [PATCH] Unify the implementation of activation operation (#32348) --- paddle/fluid/operators/activation_op.cu | 1112 +++++++++++++++-------- paddle/fluid/operators/activation_op.h | 4 +- 2 files changed, 759 insertions(+), 357 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 781a97c1ff..836c5fa06f 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -10,382 +10,719 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/math/math_cuda_utils.h" #include "paddle/fluid/platform/cuda_device_function.h" -#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using float16 = paddle::platform::float16; +template +struct CudaReluFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + + // relu(x) = max(x, 0) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] > zero ? args[0] : zero; + } +}; template -struct CudaVecType { - using type = T; - static constexpr int vecsize = 1; +struct CudaReluGradFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + + // dx = dout * (out > 0) + // Inputs: args[0], the input dout + // args[1], the input out + __device__ __forceinline__ T operator()(const T* args) const { + return args[1] > zero ? args[0] : zero; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -template <> -struct CudaVecType { - using type = __half2; - static constexpr int vecsize = 2; +template +struct CudaLeakyReluFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + float alpha; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + // leakyrelu(x) = x > 0 ? x : alpha * x + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] > zero ? args[0] : static_cast(alpha) * args[0]; + } }; -template <> -struct CudaVecType { - using type = float4; - static constexpr int vecsize = 4; +template +struct CudaLeakyReluGradFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + float alpha; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + // dx = dout * (x > 0 ? 1 : alpha) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + return args[1] > zero ? 
args[0] : static_cast(alpha) * args[0]; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template -class BaseGPUFunctor { - public: - using ELEMENT_TYPE = T; +struct CudaSigmoidFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // sigmoid(x) = 1 / (1 + exp(-x)) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(one / (one + exp(-x))); + } +}; - using AttrPair = std::vector>; +template +struct CudaSigmoidGradFunctor : public BaseActivationFunctor { + T one = static_cast(1.0f); + + // dx = dout * out * (1 - out) + // Inputs: args[0], the input dout + // args[1], the input out + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] * args[1] * (one - args[1]); + } - AttrPair GetAttrs() { return AttrPair(); } + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -/* ========================================================================== */ +template +struct CudaSiluFunctor : public BaseActivationFunctor { + // MPType means Compute Type + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // silu(x) = x / (1 + exp(-x)) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(x / (one + exp(-x))); + } +}; -/* =========================== relu forward ============================ */ template -class ReluGPUFunctor : public BaseGPUFunctor { - private: - T zero_; +struct CudaSiluGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // dx = dout * (1 + exp(-x) + x * exp(-x) / (1 + exp(-x))^2) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + MPType temp = one / (one + exp(-x)); + return static_cast(dout * (temp * (one + x * (one - temp)))); + } - public: - ReluGPUFunctor() { zero_ = static_cast(0.0f); } - - // for relu forward when T is double - __device__ __forceinline__ typename CudaVecType::type Compute( - const typename CudaVecType::type in) { - // relu forward : out = max(x, 0) - return in > zero_ ? in : zero_; - } - - // when num % vecsize != 0 this func will be used - __device__ __forceinline__ T ComputeRemainder(const T in) { - // relu forward : out = max(x, 0) - return in > zero_ ? 
in : zero_; - } -}; - -template <> -__device__ __forceinline__ CudaVecType::type -ReluGPUFunctor::Compute(const CudaVecType::type in) { - // relu forward : out = max(in, 0) - return make_float4((in.x > zero_) * (in.x), (in.y > zero_) * (in.y), - (in.z > zero_) * (in.z), (in.w > zero_) * (in.w)); -} - -template <> -__device__ __forceinline__ CudaVecType::type -ReluGPUFunctor::Compute(const CudaVecType::type in) { -// relu forward : out = max(in, 0) -#ifdef __HIPCC__ || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - const half2 kzero = __float2half2_rn(0.0f); - return __hmul2(__hgt2(in, kzero), in); -#else - const float2 xx = __half22float2(in); - return __floats2half2_rn((xx.x > 0.0f) * static_cast(xx.x), - (xx.y > 0.0f) * static_cast(xx.y)); -#endif -} -/* ========================================================================== */ + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; -/* =========================== relu backward ============================ - */ +template +struct CudaLogSigmoidFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType zero = static_cast(0.0f); + + // logsigmoid(x) = log(1 / (1 + exp(-x))) + // For numerical stability, + // logsigmoid(x) = + // - (max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + MPType temp = x > zero ? zero : -x; + return static_cast(-temp - log(exp(-temp) + exp(-x - temp))); + } +}; template -class ReluGradGPUFunctor : public BaseGPUFunctor { - private: - T zero_; +struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType zero = static_cast(0.0f); + + // dx = dout * exp(-x) / (1 + exp(-x)) + // For numerical stability: + // dx = dout * exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + exp(-x - max(-x, + // 0))) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + MPType temp1 = x > zero ? zero : -x; + MPType temp2 = exp(-x - temp1); + return static_cast(dout * (temp2 / (exp(-temp1) + temp2))); + } - public: - ReluGradGPUFunctor() { zero_ = static_cast(0.0f); } + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct CudaAtanFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // atan(x) = atan(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(atan(x)); + } +}; + +template +struct CudaAtanGradFunctor : public BaseActivationFunctor { + T one = static_cast(1.0f); + + // dx = dout / (1 + x^2) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] / (one + args[1] * args[1]); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct CudaSoftShrinkFunctor : public BaseActivationFunctor { + float lambda; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"lambda", &lambda}}; + } + + // softshrink(x) = x - lambda, if x > lambda; + // x + lambda, if x < -lambda; + // 0, otherwise. 
+ // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + T x = args[0]; + T l = static_cast(lambda); + T temp1 = static_cast(x > l); + T temp2 = static_cast(x < -l); + return temp1 * (x - l) + temp2 * (x + l); + } +}; + +template +struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + float lambda; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"lambda", &lambda}}; + } + + // dx = dout, if x > lambda or x < -lambda else 0 + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + T x = args[1]; + T l = static_cast(lambda); + return (x >= -l && x <= l) ? zero : args[0]; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct CudaCeilFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // ceil(x) = ceil(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(ceil(x)); + } +}; + +template +struct CudaFloorFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // floor(x) = floor(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(floor(x)); + } +}; + +template +struct CudaRoundFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // round(x) = round(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(round(x)); + } +}; + +// grad functor for ceil, floor and round +template +struct CudaZeroGradFunctor : public BaseActivationFunctor { + __device__ __forceinline__ T operator()(const T* args) const { + return static_cast(0.0f); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; } +}; + +template +struct CudaCosFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // cos(x) = cos(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(cos(x)); + } +}; + +template +struct CudaCosGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // dx = dout * (-sin(x)) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(-dout * sin(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct CudaSinFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // sin(x) = sin(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(sin(x)); + } +}; + +template +struct CudaSinGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // dx = dout * cos(x) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return 
static_cast(dout * cos(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct CudaTanFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // tan(x) = tan(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(tan(x)); + } +}; + +template +struct CudaTanGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // dx = dout / cos(x)^2 + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(dout / (cos(x) * cos(x))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct CudaAsinFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // asin(x) = asin(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(asin(x)); + } +}; + +template +struct CudaAsinGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // dx = dout / sqrt(1 - x^2) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(dout / sqrt(one - x * x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct CudaAcosFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // acos(x) = acos(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(acos(x)); + } +}; + +template +struct CudaAcosGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // dx = -dout / sqrt(1 - x^2) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(-dout / sqrt(one - x * x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; - // for relu backward when T is double - __device__ __forceinline__ typename CudaVecType::type Compute( - const typename CudaVecType::type out, - const typename CudaVecType::type dout) { - return out > zero_ ? 
dout : zero_; +template +struct CudaCoshFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // cosh(x) = cosh(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(cosh(x)); + } +}; + +template +struct CudaCoshGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // dx = dout * sinh(x) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(dout * sinh(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct CudaSinhFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // sinh(x) = sinh(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(sinh(x)); + } +}; + +template +struct CudaSinhGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // dx = dout * cosh(x) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(dout * cosh(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct CudaTanhFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // tanh(x) = tanh(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(tanh(x)); } +}; - // when num % vecsize != 0 this func will be used - __device__ __forceinline__ T ComputeRemainder(const T out, const T dout) { - // relu backward : dx = out > 0 ? dout : 0 - return out > zero_ ? dout : zero_; +template +struct CudaTanhGradFunctor : public BaseActivationFunctor { + T one = static_cast(1.0f); + + // dx = dout * (1 - out^2) + // Inputs: args[0], the input dout + // args[1], the input out + __device__ __forceinline__ T operator()(const T* args) const { + T dout = static_cast(args[0]); + T out = static_cast(args[1]); + return dout * (one - out * out); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -template <> -__device__ __forceinline__ CudaVecType::type -ReluGradGPUFunctor::Compute(const CudaVecType::type out, - const CudaVecType::type dout) { - // relu backward : dx = out > 0 ? dout : 0; - return make_float4((out.x > zero_) * (dout.x), (out.y > zero_) * (dout.y), - (out.z > zero_) * (dout.z), (out.w > zero_) * (dout.w)); -} - -template <> -__device__ __forceinline__ CudaVecType::type -ReluGradGPUFunctor::Compute(const CudaVecType::type out, - const CudaVecType::type dout) { -// relu backward : dx = out > 0 ? 
dout : 0; -#ifdef __HIPCC__ || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - const half2 kzero = __float2half2_rn(0.0f); - return __hmul2(__hgt2(out, kzero), dout); -#else - const float2 xx = __half22float2(out); - const float2 yy = __half22float2(dout); - return __floats2half2_rn((xx.x > 0.0f) * static_cast(yy.x), - (xx.y > 0.0f) * static_cast(yy.y)); -#endif -} +template +struct CudaReciprocalFunctor : public BaseActivationFunctor { + T one = static_cast(1.0f); + + // reciprocal(x) = 1 / x + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + return one / args[0]; + } +}; -/* ========================================================================== */ -/* ======================== leaky relu forward ======================== - */ template -class LeakyReluGPUFunctor : public BaseGPUFunctor { - private: - T zero_; - float alpha_; +struct CudaReciprocalGradFunctor : public BaseActivationFunctor { + // dx = -dout * out^2 + // Inputs: args[0], the input dout + // args[1], the input out + __device__ __forceinline__ T operator()(const T* args) const { + return -args[0] * args[1] * args[1]; + } - public: - LeakyReluGPUFunctor() { zero_ = static_cast(0.0f); } + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha_}}; - } - // leakyrelu forward : out = x > 0 ? x : x * alpha - __device__ __forceinline__ typename CudaVecType::type Compute( - const typename CudaVecType::type in) { - return in > zero_ ? in : static_cast(alpha_) * in; - } - - __device__ __forceinline__ T ComputeRemainder(const T in) { - // leakyrelu forward : out = x > 0 ? x : x * alpha - return in > zero_ ? in : static_cast(alpha_) * in; - } -}; - -template <> -__device__ __forceinline__ CudaVecType::type -LeakyReluGPUFunctor::Compute(const CudaVecType::type in) { - // leakyrelu forward : out = x > 0 ? x : x * alpha - return make_float4((in.x > zero_) ? (in.x) : (in.x) * alpha_, - (in.y > zero_) ? (in.y) : (in.y) * alpha_, - (in.z > zero_) ? (in.z) : (in.z) * alpha_, - (in.w > zero_) ? (in.w) : (in.w) * alpha_); -} - -template <> -__device__ __forceinline__ CudaVecType::type -LeakyReluGPUFunctor::Compute(const CudaVecType::type in) { - // leakyrelu forward : out = x > 0 ? x : x * alpha - const float2 xx = __half22float2(in); - return __floats2half2_rn((xx.x > 0.0f) ? xx.x : xx.x * alpha_, - (xx.y > 0.0f) ? 
xx.y : xx.y * alpha_); -} -/* ========================================================================== */ +template +struct CudaExpFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // exp(x) = exp(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(exp(x)); + } +}; -/* =========================== leaky relu backward ======================= - */ template -class LeakyReluGradGPUFunctor : public BaseGPUFunctor { - private: - T zero_; - float alpha_; +struct CudaExpGradFunctor : public BaseActivationFunctor { + // dx = dout * out + // Inputs: args[0], the input dout + // args[1], the input out + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] * args[1]; + } - public: - LeakyReluGradGPUFunctor() { zero_ = static_cast(0.0f); } + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha_}}; +template +struct CudaLogFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // log(x) = log(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(log(x)); + } +}; + +template +struct CudaLogGradFunctor : public BaseActivationFunctor { + // dx = dout / x + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] / args[1]; } - // for leaky relu backward when T is double - __device__ __forceinline__ typename CudaVecType::type Compute( - const typename CudaVecType::type in, - const typename CudaVecType::type dout) { - // leakyrelu backward : dx = x > 0 ? dout : alpha * dout - return in > zero_ ? dout : static_cast(alpha_) * dout; + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct CudaSquareFunctor : public BaseActivationFunctor { + // square(x) = x * x + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] * args[0]; } +}; - // when num % vecsize != 0 this func will be used - __device__ __forceinline__ T ComputeRemainder(const T in, const T dout) { - // leakyrelu backward : dx = x > 0 ? dout : alpha * dout - return in > zero_ ? dout : static_cast(alpha_) * dout; +template +struct CudaSquareGradFunctor : public BaseActivationFunctor { + T two = static_cast(2.0f); + + // dx = dout * 2 * x + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] * two * args[1]; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -template <> -__device__ __forceinline__ CudaVecType::type -LeakyReluGradGPUFunctor::Compute(const CudaVecType::type in, - const CudaVecType::type dout) { - // leakyrelu backward : dx = x > 0 ? dout : alpha * dout - return make_float4((in.x > zero_) ? (dout.x) : alpha_ * (dout.x), - (in.y > zero_) ? (dout.y) : alpha_ * (dout.y), - (in.z > zero_) ? (dout.z) : alpha_ * (dout.z), - (in.w > zero_) ? (dout.w) : alpha_ * (dout.w)); -} - -template <> -__device__ __forceinline__ CudaVecType::type LeakyReluGradGPUFunctor< - float16>::Compute(const CudaVecType::type in, - const CudaVecType::type dout) { - // leakyrelu backward : dx = x > 0 ? 
dout : alpha * dout - const float2 xx = __half22float2(in); - const float2 yy = __half22float2(dout); - return __floats2half2_rn((xx.x > 0.0f) ? yy.x : alpha_ * yy.x, - (xx.y > 0.0f) ? yy.y : alpha_ * yy.y); -} +template +struct CudaSqrtFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // sqrt(x) = sqrt(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(sqrt(x)); + } +}; -/* ========================================================================== */ +template +struct CudaSqrtGradFunctor : public BaseActivationFunctor { + T one_half = static_cast(0.5f); + + // dx = dout * 0.5 / out + // Inputs: args[0], the input dout + // args[1], the input out + __device__ __forceinline__ T operator()(const T* args) const { + return one_half * args[0] / args[1]; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; -template -__global__ void ActivationGradKernelVec(const T* forward_data, const T* dout, - T* dx, int num, Functor functor) { - using VecType = typename CudaVecType::type; - constexpr int vecsize = CudaVecType::vecsize; - int idx = threadIdx.x + blockIdx.x * blockDim.x; - int stride = blockDim.x * gridDim.x; - int loop = num / vecsize; - int tail = num % vecsize; - const VecType* in_forward = reinterpret_cast(forward_data); - const VecType* in_dout = reinterpret_cast(dout); - VecType* out = reinterpret_cast(dx); - VecType forward_vec, dout_vec; - T in_data, dout_data; - for (int i = idx; i < loop; i += stride) { -#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350 - forward_vec = __ldg(in_forward + i); - dout_vec = __ldg(in_dout + i); -#else - forward_vec = in_forward[i]; - dout_vec = in_dout[i]; -#endif - out[i] = functor.Compute(forward_vec, dout_vec); - } - - while (idx == loop && tail) { - in_data = forward_data[num - tail]; - dout_data = dout[num - tail]; - dx[num - tail] = functor.ComputeRemainder(in_data, dout_data); - --tail; - } -} - -template -__global__ void ActivationkernelVec(const T* src, T* dst, int num, - Functor functor) { - constexpr int vecsize = CudaVecType::vecsize; - using VecType = typename CudaVecType::type; - int idx = threadIdx.x + blockIdx.x * blockDim.x; - int stride = blockDim.x * gridDim.x; - int loop = num / vecsize; - int tail = num % vecsize; - const VecType* in = reinterpret_cast(src); - VecType* out = reinterpret_cast(dst); - VecType x_vec; - for (int i = idx; i < loop; i += stride) { -#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350 - x_vec = __ldg(in + i); -#else - x_vec = in[i]; -#endif - out[i] = functor.Compute(x_vec); - } - - while (idx == loop && tail) { - dst[num - tail] = functor.ComputeRemainder(src[num - tail]); - --tail; - } -} +template +struct CudaRsqrtFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // rsqrt(x) = rsqrt(x) + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(rsqrt(x)); + } +}; + +template +struct CudaRsqrtGradFunctor : public BaseActivationFunctor { + T minus_one_half = static_cast(-0.5f); + + // dx = dout * -0.5 / out^3 + // Inputs: args[0], the input dout + // args[1], the input out + __device__ __forceinline__ T operator()(const T* args) const { + T out = args[1]; + return minus_one_half * args[0] * out * out * out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; template -class 
ActivationGPUKernel +class ActivationCudaKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; - void Compute(const framework::ExecutionContext& context) const override { - const framework::Tensor* in_x = nullptr; + void Compute(const framework::ExecutionContext& ctx) const override { + const framework::Tensor* x = nullptr; framework::Tensor* out = nullptr; - ExtractActivationTensor(context, &in_x, &out); - auto& dev_ctx = context.template device_context(); - - int num = in_x->numel(); - const T* input_data = in_x->data(); - T* output_data = out->mutable_data(dev_ctx.GetPlace(), - static_cast(num * sizeof(T))); - - int block = 512; -#ifdef __HIPCC__ - block = 256; -#endif - Functor functor; + ExtractActivationTensor(ctx, &x, &out); + out->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + std::vector ins = {x}; + std::vector outs = {out}; + auto functor = Functor(); auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { - *attr.second = context.Attr(attr.first); + *attr.second = ctx.Attr(attr.first); } - constexpr int vecsize = CudaVecType::vecsize; - int grid = max((num / vecsize + block - 1) / block, 1); - auto stream = context.cuda_device_context().stream(); - ActivationkernelVec<<>>( - input_data, output_data, num, functor); + LaunchElementwiseCudaKernel(dev_ctx, ins, &outs, + functor); } }; template -class ActivationGradGPUKernel +class ActivationGradCudaKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; - void Compute(const framework::ExecutionContext& context) const override { + void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *x, *out, *d_out; framework::Tensor* d_x = nullptr; x = out = d_out = nullptr; - ExtractActivationGradTensor(context, &x, &out, &d_out, + ExtractActivationGradTensor(ctx, &x, &out, &d_out, &d_x); - int numel = d_out->numel(); - auto& dev_ctx = context.template device_context(); - auto* dx_data = d_x->mutable_data( - dev_ctx.GetPlace(), static_cast(numel * sizeof(T))); - auto* dout_data = d_out->data(); + d_x->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + auto functor = Functor(); + auto attrs = functor.GetAttrs(); + for (auto& attr : attrs) { + *attr.second = ctx.Attr(attr.first); + } + + std::vector ins = {d_out}; + std::vector outs = {d_x}; - auto* forward_data = dout_data; if (static_cast(Functor::FwdDeps()) == static_cast(kDepOut)) { // Only need forward output Out - forward_data = out->data(); + ins.push_back(out); + LaunchElementwiseCudaKernel(dev_ctx, ins, + &outs, functor); } else if (static_cast(Functor::FwdDeps()) == static_cast(kDepX)) { // Only need forward input X - forward_data = x->data(); + ins.push_back(x); + LaunchElementwiseCudaKernel(dev_ctx, ins, + &outs, functor); + } else { + LaunchElementwiseCudaKernel(dev_ctx, ins, + &outs, functor); } - - int block = 512; -#ifdef __HIPCC__ - block = 256; -#endif - - Functor functor; - auto attrs = functor.GetAttrs(); - for (auto& attr : attrs) { - *attr.second = context.Attr(attr.first); - } - constexpr int vecsize = CudaVecType::vecsize; - int grid = max((numel / vecsize + block - 1) / block, 1); - auto stream = context.cuda_device_context().stream(); - ActivationGradKernelVec<<>>( - forward_data, dout_data, dx_data, numel, functor); } }; @@ -395,12 +732,13 @@ class ActivationGradGPUKernel namespace ops = paddle::operators; namespace plat = paddle::platform; -#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, 
functor, \ - grad_functor) \ +#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor, \ + grad_functor) \ REGISTER_OP_CUDA_KERNEL( \ - act_type, \ - ops::ActivationKernel>, \ - ops::ActivationKernel>, \ + act_type, ops::ActivationKernel>, \ + ops::ActivationKernel>, \ ops::ActivationKernel>); \ REGISTER_OP_CUDA_KERNEL( \ @@ -410,28 +748,28 @@ namespace plat = paddle::platform; ops::grad_functor>, \ ops::ActivationGradKernel>); -FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL); -#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor, \ - grad_functor) \ +#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \ + grad_functor) \ REGISTER_OP_CUDA_KERNEL( \ - act_type, ops::ActivationGPUKernel>, \ - ops::ActivationGPUKernel>, \ - ops::ActivationGPUKernel>); \ + act_type, ops::ActivationCudaKernel>, \ + ops::ActivationCudaKernel>, \ + ops::ActivationCudaKernel>); \ REGISTER_OP_CUDA_KERNEL( \ - act_type##_grad, ops::ActivationGradGPUKernel>, \ - ops::ActivationGradGPUKernel>, \ - ops::ActivationGradGPUKernel>); + act_type##_grad, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>); /* ======================== leaky relu register ============================ */ -REGISTER_ACTIVATION_GPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluGPUFunctor, - LeakyReluGradGPUFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor, + CudaLeakyReluGradFunctor); REGISTER_OP_CUDA_KERNEL( leaky_relu_grad_grad, @@ -444,7 +782,7 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ /* ======================== elu register ============================ */ -REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor); REGISTER_OP_CUDA_KERNEL( elu_grad_grad, ops::ELUDoubleGradKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); + square, ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>); REGISTER_OP_CUDA_KERNEL( - square_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); + square_grad, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>); REGISTER_OP_CUDA_KERNEL( square_grad_grad, @@ -564,27 +910,29 @@ REGISTER_OP_CUDA_KERNEL( /* ========================== exp register ============================ */ REGISTER_OP_CUDA_KERNEL( - exp, ops::ActivationKernel>, - ops::ActivationKernel>, + exp, ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, ops::ActivationKernel>, ops::ActivationKernel>, - ops::ActivationKernel>); + ops::ActivationCudaKernel>); REGISTER_OP_CUDA_KERNEL( - exp_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); + exp_grad, ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>); /* ========================================================================== */ /* ========================== Log register ==================================*/ 
-REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, LogFunctor, LogGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor); REGISTER_OP_CUDA_KERNEL( log_grad_grad, ops::LogDoubleGradKernel>); /* ========================================================================== */ + +REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor, + CudaSigmoidGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(silu, Silu, CudaSiluFunctor, + CudaSiluGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, + CudaLogSigmoidGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(atan, Atan, CudaAtanFunctor, + CudaAtanGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(softshrink, SoftShrink, CudaSoftShrinkFunctor, + CudaSoftShrinkGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(ceil, Ceil, CudaCeilFunctor, + CudaZeroGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(floor, Floor, CudaFloorFunctor, + CudaZeroGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(cos, Cos, CudaCosFunctor, CudaCosGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(tan, Tan, CudaTanFunctor, CudaTanGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(acos, Acos, CudaAcosFunctor, + CudaAcosGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(sin, Sin, CudaSinFunctor, CudaSinGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(asin, Asin, CudaAsinFunctor, + CudaAsinGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(sinh, Sinh, CudaSinhFunctor, + CudaSinhGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(cosh, Cosh, CudaCoshFunctor, + CudaCoshGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(round, Round, CudaRoundFunctor, + CudaZeroGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, CudaReciprocalFunctor, + CudaReciprocalGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(log2, Log2, Log2Functor, Log2GradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(log10, Log10, Log10Functor, Log10GradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(brelu, BRelu, BReluFunctor, BReluGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(soft_relu, SoftRelu, SoftReluFunctor, + SoftReluGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(stanh, STanh, STanhFunctor, STanhGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(softplus, Softplus, SoftplusFunctor, + SoftplusGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(softsign, Softsign, SoftsignFunctor, + SoftsignGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(relu6, Relu6, Relu6Functor, Relu6GradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(tanh_shrink, TanhShrink, TanhShrinkFunctor, + TanhShrinkGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(hard_shrink, HardShrink, HardShrinkFunctor, + HardShrinkGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, + HardSigmoidGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(swish, Swish, SwishFunctor, SwishGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(thresholded_relu, ThresholdedRelu, + ThresholdedReluFunctor, + ThresholdedReluGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(hard_swish, HardSwish, HardSwishFunctor, + HardSwishGradFunctor); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 7245dea9cf..ccd5bf528b 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -455,7 +455,7 @@ struct HardShrinkFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out) const { auto temp1 = x < static_cast(threshold * -1.f); auto temp2 = x > static_cast(threshold); - 
out.device(d) = x * (temp1 + temp2).template cast<T>();
+    out.device(d) = x * (temp1 || temp2).template cast<T>();
   }
 };
 
@@ -472,7 +472,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto temp1 = x < static_cast<T>(threshold * -1.f);
     auto temp2 = x > static_cast<T>(threshold);
-    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
+    dx.device(d) = dout * (temp1 || temp2).template cast<T>();
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
--
GitLab
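
For readers skimming the patch, here is a minimal standalone sketch of the pattern it introduces: each activation becomes a small device functor exposing __device__ T operator()(const T* args), and a single generic elementwise kernel applies it to every element. The launcher below (ElementwiseUnaryKernel) and the main driver are simplified stand-ins for illustration only; they are not Paddle's LaunchElementwiseCudaKernel or its kernel-registration machinery, and the grid/block configuration is arbitrary.

// sketch.cu -- compile with: nvcc sketch.cu -o sketch
// Illustrative only: in the real patch the functors derive from
// BaseActivationFunctor<T> and are dispatched via LaunchElementwiseCudaKernel.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
struct CudaReluFunctor {
  T zero = static_cast<T>(0.0f);
  // relu(x) = max(x, 0); args[0] is the input x, mirroring the patch's convention.
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] > zero ? args[0] : zero;
  }
};

// Simplified stand-in for the generic elementwise launcher: one input, one output.
template <typename T, typename Functor>
__global__ void ElementwiseUnaryKernel(const T* in, T* out, int n, Functor f) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = idx; i < n; i += blockDim.x * gridDim.x) {
    T args[1] = {in[i]};
    out[i] = f(args);
  }
}

int main() {
  const int n = 8;
  float h_in[n] = {-3.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 2.f, 3.f};
  float h_out[n];
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  ElementwiseUnaryKernel<float, CudaReluFunctor<float>>
      <<<1, 128>>>(d_in, d_out, n, CudaReluFunctor<float>());
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

  for (int i = 0; i < n; ++i) printf("relu(%g) = %g\n", h_in[i], h_out[i]);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

The gradient functors in the patch follow the same shape: args[0] is dout, and args[1] is either out or x depending on what FwdDeps() reports (kDepOut vs. kDepX), which is how ActivationGradCudaKernel decides which forward tensor to append to the input list.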