diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu
index 836c5fa06f6dfe63a134d75fde91eb1a061ce1f0..22f8147111ffa5be91813738ff147a19b9ef22bc 100644
--- a/paddle/fluid/operators/activation_op.cu
+++ b/paddle/fluid/operators/activation_op.cu
@@ -663,6 +663,640 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
 
+template <typename T>
+struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType one = static_cast<MPType>(1.0f);
+
+  // log1p(x) = log(1 + x)
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType x = static_cast<MPType>(args[0]);
+    return static_cast<T>(log(one + x));
+  }
+};
+
+template <typename T>
+struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
+  T one = static_cast<T>(1.0f);
+
+  // dx = dout / (1 + x)
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return args[0] / (one + args[1]);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct CudaLog2Functor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+
+  // log2(x) = log(x) / log(2)
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType x = static_cast<MPType>(args[0]);
+    return static_cast<T>(log2(x));
+  }
+};
+
+template <typename T>
+struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));
+
+  // dx = dout / (x * log(2))
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return args[0] / (args[1] * log_two);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct CudaLog10Functor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+
+  // log10(x) = log(x) / log(10)
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType x = static_cast<MPType>(args[0]);
+    return static_cast<T>(log10(x));
+  }
+};
+
+template <typename T>
+struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));
+
+  // dx = dout / (x * log(10))
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return args[0] / (args[1] * log_ten);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct CudaBReluFunctor : public BaseActivationFunctor<T> {
+  float t_min;
+  float t_max;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"t_min", &t_min}, {"t_max", &t_max}};
+  }
+
+  // brelu(x) = min(max(x, t_min), t_max)
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T x = args[0];
+    T t_min_cast = static_cast<T>(t_min);
+    T t_max_cast = static_cast<T>(t_max);
+    T temp_max = x > t_min_cast ? x : t_min_cast;
+    T temp_min = temp_max < t_max_cast ? temp_max : t_max_cast;
+    return temp_min;
+  }
+};
+
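The functors above all follow the same mixed-precision pattern: MPTypeTrait<T>::Type promotes a float16 input to float before the transcendental call, then the result is cast back to T. A minimal host-side sketch of that idea (standalone C++, not Paddle code; low_t is a hypothetical stand-in for plat::float16):

```cpp
#include <cmath>
#include <cstdio>

// Compute log1p in a wider "MPType" and cast back, mirroring CudaLog1pFunctor.
template <typename T, typename MPType>
T log1p_like(T x) {
  MPType one = static_cast<MPType>(1.0f);
  return static_cast<T>(std::log(one + static_cast<MPType>(x)));
}

int main() {
  using low_t = float;  // assumption: stand-in for a 16-bit float type
  printf("%.10f\n", log1p_like<low_t, double>(1e-4f));  // ~0.0000999950
  return 0;
}
```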
+template <typename T>
+struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  float t_min;
+  float t_max;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"t_min", &t_min}, {"t_max", &t_max}};
+  }
+
+  // dx = (x > t_min && x < t_max) ? dout : 0
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T dout = args[0];
+    T x = args[1];
+    T t_min_cast = static_cast<T>(t_min);
+    T t_max_cast = static_cast<T>(t_max);
+    return (x > t_min_cast && x < t_max_cast) ? dout : zero;
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType one = static_cast<MPType>(1.0f);
+  float threshold;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  // soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold)))
+  // Inputs: args[0], the input x
+  // threshold should not be negative
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType x = static_cast<MPType>(args[0]);
+    MPType t = static_cast<MPType>(threshold);
+    MPType temp_min = x < t ? x : t;
+    MPType temp_max = temp_min > -t ? temp_min : -t;
+    return static_cast<T>(log(one + exp(temp_max)));
+  }
+};
+
+template <typename T>
+struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType one = static_cast<MPType>(1.0f);
+  float threshold;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  // dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0
+  // Inputs: args[0], the input dout
+  //         args[1], the input out
+  // threshold should not be negative
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType dout = static_cast<MPType>(args[0]);
+    MPType out = static_cast<MPType>(args[1]);
+    MPType t = static_cast<MPType>(threshold);
+    return (out > -t && out < t) ? static_cast<T>(dout * (one - exp(-out)))
+                                 : static_cast<T>(0.0f);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+};
+
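A note on the kDepOut choice in CudaSoftReluGradFunctor: inside the clip window, out = log(1 + exp(x)), so dout/dx = exp(x) / (1 + exp(x)) = 1 - exp(-out), and the backward pass can be written purely in terms of the forward output. A quick host-side check of that identity (sketch, plain C++):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  double x = 0.7;
  double out = std::log(1.0 + std::exp(x));
  double grad_from_x = std::exp(x) / (1.0 + std::exp(x));  // d/dx log(1+e^x)
  double grad_from_out = 1.0 - std::exp(-out);             // what the functor computes
  printf("%.12f\n%.12f\n", grad_from_x, grad_from_out);    // identical
  return 0;
}
```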
+template <typename T>
+struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  float scale_a;
+  float scale_b;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
+  }
+
+  // stanh(x) = b * tanh(a * x)
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType x = static_cast<MPType>(args[0]);
+    MPType a = static_cast<MPType>(scale_a);
+    MPType b = static_cast<MPType>(scale_b);
+    return static_cast<T>(b * tanh(a * x));
+  }
+};
+
+template <typename T>
+struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType one = static_cast<MPType>(1.0f);
+  float scale_a;
+  float scale_b;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
+  }
+
+  // dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x))
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType dout = static_cast<MPType>(args[0]);
+    MPType x = static_cast<MPType>(args[1]);
+    MPType a = static_cast<MPType>(scale_a);
+    MPType b = static_cast<MPType>(scale_b);
+    MPType temp = tanh(a * x);
+    return static_cast<T>(dout * a * b * (one - temp * temp));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType one = static_cast<MPType>(1.0f);
+  float beta;
+  float threshold;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+
+  // softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType x = static_cast<MPType>(args[0]);
+    MPType b = static_cast<MPType>(beta);
+    MPType t = static_cast<MPType>(threshold);
+    MPType x_beta = x * b;
+    return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
+  }
+};
+
+template <typename T>
+struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType one = static_cast<MPType>(1.0f);
+  float beta;
+  float threshold;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+
+  // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType dout = static_cast<MPType>(args[0]);
+    MPType x = static_cast<MPType>(args[1]);
+    MPType b = static_cast<MPType>(beta);
+    MPType t = static_cast<MPType>(threshold);
+    MPType x_beta = x * b;
+    return x_beta > t ? args[0] : static_cast<T>(dout / (one + exp(-x_beta)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
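Why the threshold branch in the softplus pair matters: once beta * x passes the threshold, exp(beta * x) is at or near float overflow, while log(1 + exp(beta * x)) / beta is already indistinguishable from x, because log(1 + e^z) = z + log1p(e^-z). A host-side sketch of that bound (the default threshold of 20 is an assumption based on the op's usual attribute default):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  double beta = 1.0, x = 25.0;  // beta * x above the assumed default threshold of 20
  double z = beta * x;
  double exact = (z + std::log1p(std::exp(-z))) / beta;  // stable rewrite of log(1+e^z)/beta
  printf("%.17g\n%.17g\n", exact, x);  // agree to ~1e-11 absolute
  return 0;
}
```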
+template <typename T>
+struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
+  T one = static_cast<T>(1.0f);
+
+  // softsign(x) = x / (1 + abs(x))
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return args[0] / (one + abs(args[0]));
+  }
+};
+
+template <typename T>
+struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
+  T one = static_cast<T>(1.0f);
+
+  // dx = dout / (1 + abs(x))^2
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T temp = one + abs(args[1]);
+    return args[0] / (temp * temp);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct CudaRelu6Functor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  float threshold;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  // relu6(x) = min(max(0, x), 6)
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T t = static_cast<T>(threshold);
+    return args[0] <= zero ? zero : (args[0] < t ? args[0] : t);
+  }
+};
+
+template <typename T>
+struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  float threshold;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  // dx = (out > 0 && out < t) ? dout : 0
+  // Inputs: args[0], the input dout
+  //         args[1], the input out
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T t = static_cast<T>(threshold);
+    return (args[1] > zero && args[1] < t) ? args[0] : zero;
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+};
+
+template <typename T>
+struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+
+  // tanhshrink(x) = x - tanh(x)
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType x = static_cast<MPType>(args[0]);
+    return static_cast<T>(x - tanh(x));
+  }
+};
+
+template <typename T>
+struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+
+  // dx = dout * tanh(x)^2
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType dout = static_cast<MPType>(args[0]);
+    MPType x = static_cast<MPType>(args[1]);
+    return static_cast<T>(dout * tanh(x) * tanh(x));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  float threshold;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  // hardshrink(x) = (x > -threshold && x < threshold) ? 0 : x
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T x = args[0];
+    T t = static_cast<T>(threshold);
+    return (x > -t && x < t) ? zero : x;
+  }
+};
+
+template <typename T>
+struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  float threshold;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  // dx = (x > -threshold && x < threshold) ? 0 : dout
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T x = args[1];
+    T t = static_cast<T>(threshold);
+    return (x > -t && x < t) ? zero : args[0];
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
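The tanh_shrink backward comment can be checked directly, since d/dx (x - tanh(x)) = 1 - (1 - tanh(x)^2) = tanh(x)^2, which is exactly what CudaTanhShrinkGradFunctor computes. A finite-difference sanity check (sketch, plain C++):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  double x = 0.3, h = 1e-6;
  double numeric = ((x + h - std::tanh(x + h)) - (x - std::tanh(x))) / h;
  double analytic = std::tanh(x) * std::tanh(x);
  printf("%.9f\n%.9f\n", numeric, analytic);  // ~0.0848630 both
  return 0;
}
```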
+template <typename T>
+struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  T one = static_cast<T>(1.0f);
+  float slope;
+  float offset;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"slope", &slope}, {"offset", &offset}};
+  }
+
+  // hard_sigmoid(x) = 0, when x <= -3
+  //                   1, when x >= 3
+  //                   x * slope + offset, otherwise
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T temp = args[0] * static_cast<T>(slope) + static_cast<T>(offset);
+    T temp_max = temp > zero ? temp : zero;
+    T temp_min = temp_max < one ? temp_max : one;
+    return temp_min;
+  }
+};
+
+template <typename T>
+struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  T one = static_cast<T>(1.0f);
+  float slope;
+  float offset;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"slope", &slope}, {"offset", &offset}};
+  }
+
+  // dx = (out > 0 && out < 1) ? dout * slope : 0
+  // Inputs: args[0], the input dout
+  //         args[1], the input out
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T out = args[1];
+    return (out > zero && out < one) ? args[0] * static_cast<T>(slope) : zero;
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+};
+
+template <typename T>
+struct CudaSwishFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType one = static_cast<MPType>(1.0f);
+  float beta;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}};
+  }
+
+  // swish(x) = x / (1 + exp(-beta * x))
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType x = static_cast<MPType>(args[0]);
+    MPType b = static_cast<MPType>(beta);
+    return static_cast<T>(x / (one + exp(-b * x)));
+  }
+};
+
+template <typename T>
+struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType one = static_cast<MPType>(1.0f);
+  float beta;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}};
+  }
+
+  // dx = dout * (1 / (1 + exp(-b * x)) +
+  //              b * x * exp(-b * x) / (1 + exp(-b * x))^2)
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType dout = static_cast<MPType>(args[0]);
+    MPType x = static_cast<MPType>(args[1]);
+    MPType b = static_cast<MPType>(beta);
+    MPType temp1 = one / (one + exp(-b * x));
+    MPType out = x * temp1;
+    MPType temp2 = b * out;
+    MPType temp3 = temp1 * (one - temp2);
+    return static_cast<T>(dout * (temp2 + temp3));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
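The swish backward algebra is easy to verify numerically: with s = sigmoid(b * x) and out = x * s, the functor returns dout * (b * out + s * (1 - b * out)), i.e. temp2 + temp3. Finite-difference check (sketch, plain C++):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  double b = 1.5, x = 0.4, h = 1e-6;
  auto swish = [&](double v) { return v / (1.0 + std::exp(-b * v)); };
  double s = 1.0 / (1.0 + std::exp(-b * x));
  double out = x * s;
  double analytic = b * out + s * (1.0 - b * out);  // temp2 + temp3 in the functor
  double numeric = (swish(x + h) - swish(x)) / h;
  printf("%.9f\n%.9f\n", numeric, analytic);
  return 0;
}
```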
+template <typename T>
+struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  float threshold;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  // thresholded_relu(x) = x > threshold ? x : 0
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return args[0] > static_cast<T>(threshold) ? args[0] : zero;
+  }
+};
+
+template <typename T>
+struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  float threshold;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  // dx = x > threshold ? dout : 0
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return args[1] > static_cast<T>(threshold) ? args[0] : zero;
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  float threshold;
+  float scale;
+  float offset;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
+  }
+
+  // hard_swish(x) = 0, when x <= -offset
+  //                 x , when x >= threshold - offset
+  //                 x * (x + offset) / scale, otherwise
+  // threshold = scale = 6, offset = 3 by default
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T x = args[0];
+    T t = static_cast<T>(threshold);
+    T temp = x + static_cast<T>(offset);
+    T temp_max = temp > zero ? temp : zero;
+    T temp_min = temp_max < t ? temp_max : t;
+    return temp_min * x / static_cast<T>(scale);
+  }
+};
+
+template <typename T>
+struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
+  T zero = static_cast<T>(0.0f);
+  T one = static_cast<T>(1.0f);
+  T two = static_cast<T>(2.0f);
+  float threshold;
+  float scale;
+  float offset;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
+  }
+
+  // dx = 0, when x <= -offset
+  //      dout , when x >= threshold - offset
+  //      dout * (2 * x / scale + offset / scale), otherwise
+  // threshold = scale = 6, offset = 3 by default
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    T x = args[1];
+    T o = static_cast<T>(offset);
+    T s = static_cast<T>(scale);
+    T temp1 = static_cast<T>(x + o > zero);
+    T temp2 = static_cast<T>(x + o < static_cast<T>(threshold));
+    return args[0] * (temp1 * temp2 * (two * x + o) / s + one - temp2);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
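The branchless hard_swish gradient deserves a gloss: temp1 and temp2 are 0/1 indicators for x + offset > 0 and x + offset < threshold, so temp1 * temp2 * (2x + offset) / scale + 1 - temp2 evaluates to 0 on the left flat, the linear slope in the middle, and 1 on the right flat. A host-side sketch over the documented defaults:

```cpp
#include <cstdio>
#include <initializer_list>

int main() {
  const float threshold = 6.f, scale = 6.f, offset = 3.f;  // documented defaults
  for (float x : {-4.f, 0.f, 4.f}) {
    float temp1 = (x + offset > 0.f) ? 1.f : 0.f;        // past the left knee
    float temp2 = (x + offset < threshold) ? 1.f : 0.f;  // before the right knee
    float dydx = temp1 * temp2 * (2.f * x + offset) / scale + 1.f - temp2;
    printf("x = %4.1f  dy/dx = %.3f\n", x, dydx);  // 0.000, 0.500, 1.000
  }
  return 0;
}
```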
+template <typename T>
+struct CudaELUFunctor : public BaseActivationFunctor<T> {
+  using CT = typename details::MPTypeTrait<T>::Type;
+  CT zero = static_cast<CT>(0.0f);
+  CT one = static_cast<CT>(1.0f);
+  float alpha;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+
+  // elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
+  // Inputs: args[0], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    CT x = static_cast<CT>(args[0]);
+    CT temp = static_cast<CT>(alpha) * (exp(x) - one);
+    CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
+    return static_cast<T>(res);
+  }
+};
+
+template <typename T>
+struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType zero = static_cast<MPType>(0.0f);
+  MPType one = static_cast<MPType>(1.0f);
+  float alpha;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+
+  // dx = dout, if alpha > 0 and x > 0
+  // dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
+  // dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
+  // dx = 0, if alpha <= 0 and x <= 0
+  // Inputs: args[0], the input dout
+  //         args[1], the input x
+  __device__ __forceinline__ T operator()(const T* args) const {
+    MPType dout = static_cast<MPType>(args[0]);
+    MPType x = static_cast<MPType>(args[1]);
+    MPType a = static_cast<MPType>(alpha);
+    MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
+    MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
+    MPType temp_x_pos = static_cast<MPType>(x > zero);
+    MPType temp_x_neg = static_cast<MPType>(x <= zero);
+    return static_cast<T>(
+        dout * (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * a * exp(x) +
+                temp_a_neg * temp_x_pos * (one + a * exp(x))));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
 template <typename DeviceContext, typename Functor>
 class ActivationCudaKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -732,23 +1366,6 @@ class ActivationGradCudaKernel
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor,       \
-                                       grad_functor)                     \
-  REGISTER_OP_CUDA_KERNEL(                                               \
-      act_type, ops::ActivationKernel<plat::CUDADeviceContext,           \
-                                      ops::functor<float>>,              \
-      ops::ActivationKernel<plat::CUDADeviceContext,                     \
-                            ops::functor<double>>,                       \
-      ops::ActivationKernel<plat::CUDADeviceContext,                     \
-                            ops::functor<plat::float16>>);               \
-  REGISTER_OP_CUDA_KERNEL(                                               \
-      act_type##_grad,                                                   \
-      ops::ActivationGradKernel<plat::CUDADeviceContext,                 \
-                                ops::grad_functor<float>>,               \
-      ops::ActivationGradKernel<plat::CUDADeviceContext,                 \
-                                ops::grad_functor<double>>,              \
-      ops::ActivationGradKernel<plat::CUDADeviceContext,                 \
-                                ops::grad_functor<plat::float16>>);
-
 #define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,      \
                                         grad_functor)                    \
   REGISTER_OP_CUDA_KERNEL(                                               \
       act_type, ops::ActivationCudaKernel<plat::CUDADeviceContext,       \
                                           ops::functor<float>>,          \
@@ -767,6 +1384,32 @@ namespace plat = paddle::platform;
       ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                     ops::grad_functor<plat::float16>>);
 
+#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor,  \
+                                            grad_functor)                \
+  REGISTER_OP_CUDA_KERNEL(                                               \
+      act_type, ops::ActivationCudaKernel<plat::CUDADeviceContext,       \
+                                          ops::functor<float>>,          \
+      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
+                                ops::functor<double>>,                   \
+      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
+                                ops::functor<int>>,                      \
+      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
+                                ops::functor<int64_t>>,                  \
+      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
+                                ops::functor<plat::float16>>);           \
+  REGISTER_OP_CUDA_KERNEL(                                               \
+      act_type##_grad,                                                   \
+      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
+                                    ops::grad_functor<float>>,           \
+      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
+                                    ops::grad_functor<double>>,          \
+      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
+                                    ops::grad_functor<int>>,             \
+      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
+                                    ops::grad_functor<int64_t>>,         \
+      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
+                                    ops::grad_functor<plat::float16>>);
+
 /* ======================== leaky relu register  ============================ */
 REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor,
                                 CudaLeakyReluGradFunctor);
@@ -782,7 +1425,7 @@ REGISTER_OP_CUDA_KERNEL(
 /* ========================================================================== */
 
 /* ========================       elu register    ============================ */
-REGISTER_ACTIVATION_GPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
+REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, CudaELUFunctor, CudaELUGradFunctor);
 
 REGISTER_OP_CUDA_KERNEL(
     elu_grad_grad, ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                                             ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================  square register   ============================ */
-REGISTER_OP_CUDA_KERNEL(
-    square,
-    ops::ActivationCudaKernel<plat::CUDADeviceContext,
-                              ops::CudaSquareFunctor<float>>,
-    ops::ActivationCudaKernel<plat::CUDADeviceContext,
-                              ops::CudaSquareFunctor<double>>,
-    ops::ActivationCudaKernel<plat::CUDADeviceContext,
-                              ops::CudaSquareFunctor<int>>,
-    ops::ActivationCudaKernel<plat::CUDADeviceContext,
-                              ops::CudaSquareFunctor<int64_t>>,
-    ops::ActivationCudaKernel<plat::CUDADeviceContext,
-                              ops::CudaSquareFunctor<plat::float16>>);
-REGISTER_OP_CUDA_KERNEL(
-    square_grad,
-    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
-                                  ops::CudaSquareGradFunctor<float>>,
-    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
-                                  ops::CudaSquareGradFunctor<double>>,
-    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
-                                  ops::CudaSquareGradFunctor<int>>,
-    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
-                                  ops::CudaSquareGradFunctor<int64_t>>,
-    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
-                                  ops::CudaSquareGradFunctor<plat::float16>>);
+REGISTER_ACTIVATION_CUDA_KERNEL_INT(square, Square, CudaSquareFunctor,
+                                    CudaSquareGradFunctor);
 
 REGISTER_OP_CUDA_KERNEL(
     square_grad_grad,
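For readers skimming the registration changes: the INT variant only adds int and int64_t instantiations on top of the float/double/float16 set, so the square registration above expands to roughly the following (a sketch of the macro expansion with whitespace compressed, not additional code):

```cpp
REGISTER_OP_CUDA_KERNEL(
    square,
    ops::ActivationCudaKernel<plat::CUDADeviceContext, ops::CudaSquareFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext, ops::CudaSquareFunctor<double>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext, ops::CudaSquareFunctor<int>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext, ops::CudaSquareFunctor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext, ops::CudaSquareFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    square_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext, ops::CudaSquareGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext, ops::CudaSquareGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext, ops::CudaSquareGradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext, ops::CudaSquareGradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext, ops::CudaSquareGradFunctor<plat::float16>>);
```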
@@ -890,7 +1512,6 @@ REGISTER_OP_CUDA_KERNEL(
 /* ========================================================================== */
 
 /* ==========================   pow register  ============================ */
-
 REGISTER_OP_CUDA_KERNEL(
     pow, ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
     ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
@@ -908,7 +1529,6 @@ REGISTER_OP_CUDA_KERNEL(
 /* ========================================================================== */
 
 /* ==========================   exp register  ============================ */
-
 REGISTER_OP_CUDA_KERNEL(
     exp, ops::ActivationCudaKernel<plat::CUDADeviceContext,
                                    ops::CudaExpFunctor<float>>,
@@ -943,56 +1563,44 @@ REGISTER_OP_CUDA_KERNEL(
     ops::LogGradGradFunctor<plat::float16>>);
 /* ========================================================================== */
 
-REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor,
-                                CudaSigmoidGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(silu, Silu, CudaSiluFunctor,
-                                CudaSiluGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor,
-                                CudaLogSigmoidGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(atan, Atan, CudaAtanFunctor,
-                                CudaAtanGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(softshrink, SoftShrink, CudaSoftShrinkFunctor,
-                                CudaSoftShrinkGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(ceil, Ceil, CudaCeilFunctor,
-                                CudaZeroGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(floor, Floor, CudaFloorFunctor,
-                                CudaZeroGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(cos, Cos, CudaCosFunctor, CudaCosGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(tan, Tan, CudaTanFunctor, CudaTanGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(acos, Acos, CudaAcosFunctor,
-                                CudaAcosGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(sin, Sin, CudaSinFunctor, CudaSinGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(asin, Asin, CudaAsinFunctor,
-                                CudaAsinGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(sinh, Sinh, CudaSinhFunctor,
-                                CudaSinhGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(cosh, Cosh, CudaCoshFunctor,
-                                CudaCoshGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(round, Round, CudaRoundFunctor,
-                                CudaZeroGradFunctor);
-REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, CudaReciprocalFunctor,
-                                CudaReciprocalGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(log2, Log2, Log2Functor, Log2GradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(log10, Log10, Log10Functor, Log10GradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(brelu, BRelu, BReluFunctor, BReluGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(soft_relu, SoftRelu, SoftReluFunctor,
-                               SoftReluGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(stanh, STanh, STanhFunctor, STanhGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(softplus, Softplus, SoftplusFunctor,
-                               SoftplusGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(softsign, Softsign, SoftsignFunctor,
-                               SoftsignGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(tanh_shrink, TanhShrink, TanhShrinkFunctor,
-                               TanhShrinkGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(hard_shrink, HardShrink, HardShrinkFunctor,
-                               HardShrinkGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,
-                               HardSigmoidGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(swish, Swish, SwishFunctor, SwishGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(thresholded_relu, ThresholdedRelu,
-                               ThresholdedReluFunctor,
-                               ThresholdedReluGradFunctor);
-REGISTER_ACTIVATION_GPU_KERNEL(hard_swish, HardSwish, HardSwishFunctor,
-                               HardSwishGradFunctor);
+#define FOR_EACH_ACTIVATION_CUDA_OP(__macro)                                  \
+  __macro(sigmoid, Sigmoid, CudaSigmoidFunctor, CudaSigmoidGradFunctor);      \
+  __macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor);                  \
+  __macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor,                      \
+          CudaLogSigmoidGradFunctor);                                         \
+  __macro(atan, Atan, CudaAtanFunctor, CudaAtanGradFunctor);                  \
+  __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor,                      \
+          CudaSoftShrinkGradFunctor);                                         \
+  __macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor);                  \
+  __macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor);               \
+  __macro(cos, Cos, CudaCosFunctor, CudaCosGradFunctor);                      \
+  __macro(tan, Tan, CudaTanFunctor, CudaTanGradFunctor);                      \
+  __macro(acos, Acos, CudaAcosFunctor, CudaAcosGradFunctor);                  \
+  __macro(sin, Sin, CudaSinFunctor, CudaSinGradFunctor);                      \
+  __macro(asin, Asin, CudaAsinFunctor, CudaAsinGradFunctor);                  \
+  __macro(sinh, Sinh, CudaSinhFunctor, CudaSinhGradFunctor);                  \
+  __macro(cosh, Cosh, CudaCoshFunctor, CudaCoshGradFunctor);                  \
+  __macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor);               \
+  __macro(reciprocal, Reciprocal, CudaReciprocalFunctor,                      \
+          CudaReciprocalGradFunctor);                                         \
+  __macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor);              \
+  __macro(log2, Log2, CudaLog2Functor, CudaLog2GradFunctor);                  \
+  __macro(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor);              \
+  __macro(brelu, BRelu, CudaBReluFunctor, CudaBReluGradFunctor);              \
+  __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
+  __macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor);              \
+  __macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor);  \
+  __macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);  \
+  __macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor);              \
+  __macro(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor,                     \
+          CudaTanhShrinkGradFunctor);                                         \
+  __macro(hard_shrink, HardShrink, CudaHardShrinkFunctor,                     \
+          CudaHardShrinkGradFunctor);                                         \
+  __macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor,                  \
+          CudaHardSigmoidGradFunctor);                                        \
+  __macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor);              \
+  __macro(thresholded_relu, ThresholdedRelu, CudaThresholdedReluFunctor,      \
+          CudaThresholdedReluGradFunctor);                                    \
+  __macro(hard_swish, HardSwish, CudaHardSwishFunctor,                        \
+          CudaHardSwishGradFunctor);
+FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
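The FOR_EACH_ACTIVATION_CUDA_OP list above is a classic X-macro: the op table is written once and replayed against any per-op macro, here REGISTER_ACTIVATION_CUDA_KERNEL. A self-contained miniature of the pattern (illustrative names only):

```cpp
#include <cstdio>

// Op table written once; each entry is replayed against whatever per-op
// macro is supplied (the role REGISTER_ACTIVATION_CUDA_KERNEL plays above).
#define FOR_EACH_DEMO_OP(op_macro) \
  op_macro(sigmoid);               \
  op_macro(relu6);                 \
  op_macro(swish)

#define PRINT_OP(name) printf("registering %s\n", #name)

int main() {
  FOR_EACH_DEMO_OP(PRINT_OP);
  return 0;
}
```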
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 92465c3e28401aaa3dc63cda4a89f43d83d959b7..31589ca4ae38e83307411ecbc39e8ea815987f03 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -2718,7 +2718,7 @@ create_test_act_fp16_class(TestRelu)
 create_test_act_fp16_class(TestGelu)
 create_test_act_fp16_class(TestBRelu)
 create_test_act_fp16_class(TestRelu6)
-create_test_act_fp16_class(TestSoftRelu)
+create_test_act_fp16_class(TestSoftRelu, grad_atol=0.85)
 create_test_act_fp16_class(TestELU)
 create_test_act_fp16_class(TestReciprocal)
 create_test_act_fp16_class(TestLog)
@@ -2736,7 +2736,7 @@ create_test_act_fp16_class(TestSoftplus)
 create_test_act_fp16_class(TestSoftsign)
 create_test_act_fp16_class(TestThresholdedRelu)
 create_test_act_fp16_class(TestHardSigmoid)
-create_test_act_fp16_class(TestSwish)
+create_test_act_fp16_class(TestSwish, grad_atol=0.85)
 create_test_act_fp16_class(TestHardSwish)
 
 if __name__ == "__main__":
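For context on how functors like these are consumed: each one is applied elementwise by a generic CUDA kernel that hands the functor a pointer to the thread's argument(s). The sketch below is not Paddle's actual launch path (which vectorizes loads and stores), just a minimal standalone illustration that compiles with nvcc:

```cuda
#include <cmath>
#include <cstdio>

// Toy functor in the same shape as the ones in this patch.
template <typename T>
struct Log1pLike {
  __device__ __forceinline__ T operator()(const T* args) const {
    return static_cast<T>(logf(1.0f + static_cast<float>(args[0])));
  }
};

// Generic elementwise driver: one thread per element, functor applied to a
// pointer holding that thread's argument.
template <typename T, typename Functor>
__global__ void ElementwiseApply(const T* in, T* out, int n, Functor functor) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    T arg = in[idx];
    out[idx] = functor(&arg);
  }
}

int main() {
  const int n = 4;
  float h_in[n] = {0.f, 0.5f, 1.f, 2.f}, h_out[n];
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  ElementwiseApply<<<1, 32>>>(d_in, d_out, n, Log1pLike<float>());
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("log1p(%g) = %f\n", h_in[i], h_out[i]);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```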