From a9bba5ba2fe490d0bad320aac3dd57c432f3c5db Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 30 Jun 2022 10:32:45 +0800 Subject: [PATCH] [phi]add relu6 kernel and yaml (#43549) * add relu6 kernel and yaml * format files * format code and fix bug * fix build failed --- .../tensorrt/convert/test_activation_op.cc | 2 +- paddle/fluid/operators/activation_op.cc | 1 + paddle/fluid/operators/activation_op.h | 42 +- paddle/fluid/operators/activation_op.kps | 374 +++++++++++------- paddle/phi/kernels/activation_kernel.h | 1 + .../phi/kernels/cpu/activation_grad_kernel.cc | 4 + paddle/phi/kernels/cpu/activation_kernel.cc | 2 + paddle/phi/kernels/funcs/activation_functor.h | 73 ++++ .../phi/kernels/gpu/activation_grad_kernel.cu | 4 + paddle/phi/kernels/gpu/activation_kernel.cu | 2 + paddle/phi/ops/compat/activation_sig.cc | 18 +- .../tests/unittests/test_activation_op.py | 3 +- python/paddle/nn/functional/activation.py | 2 + python/paddle/utils/code_gen/legacy_api.yaml | 10 + .../utils/code_gen/legacy_backward.yaml | 11 + 15 files changed, 348 insertions(+), 201 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc index 7a034f2c166..578b297e21e 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc @@ -56,4 +56,4 @@ TEST(Relu6OpConverter, main) { test_activation("relu6"); } USE_OP_ITSELF(relu); USE_OP_ITSELF(sigmoid); USE_OP_ITSELF(tanh); -USE_OP(relu6); +USE_OP_ITSELF(relu6); diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 0da38de5ead..8f443f6f165 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -1503,6 +1503,7 @@ REGISTER_ACTIVATION_OP(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, ThresholdedReluGradFunctor); +REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor); REGISTER_ACTIVATION_OP(hard_shrink, HardShrink, HardShrinkFunctor, diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 5d2e82b1046..26ee9eb11bc 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -281,6 +281,7 @@ USE_PHI_DOUBLE_GRAD_FUNCTOR(Tanh) USE_PHI_TRIPLE_GRAD_FUNCTOR(Tanh) USE_PHI_FUNCTOR(BRelu) USE_PHI_FUNCTOR(ThresholdedRelu) +USE_PHI_FUNCTOR(Relu6) USE_PHI_FUNCTOR(LeakyRelu) USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu) USE_PHI_FUNCTOR(HardShrink) @@ -348,44 +349,6 @@ using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor; template using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor; -// relu6(x) = min(max(0, x), 6) -template -struct Relu6Functor : public BaseActivationFunctor { - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - template - void operator()(Device d, X x, Out out) const { - out.device(d) = - x.cwiseMax(static_cast(0)).cwiseMin(static_cast(threshold)); - } -}; - -template -struct Relu6GradFunctor : public BaseActivationFunctor { - float threshold; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = - dout * ((out > static_cast(0)) * (out < static_cast(threshold))) - .template cast(); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - template 
struct SoftReluFunctor : public BaseActivationFunctor { float threshold; @@ -561,5 +524,4 @@ struct SoftsignGradFunctor : public BaseActivationFunctor { #define FOR_EACH_ACTIVATION_OP(__macro) \ __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ - __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ - __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); + __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); diff --git a/paddle/fluid/operators/activation_op.kps b/paddle/fluid/operators/activation_op.kps index 7298a058278..4e56721cb30 100644 --- a/paddle/fluid/operators/activation_op.kps +++ b/paddle/fluid/operators/activation_op.kps @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" - #include "paddle/phi/kernels/funcs/activation_functor.h" namespace paddle { @@ -67,42 +66,6 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor { } }; -template -struct CudaRelu6Functor : public BaseActivationFunctor { - T zero = static_cast(0.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - // relu6(x) = min(max(0, x), 6) - __device__ __forceinline__ T operator()(const T x) const { - T t = static_cast(threshold); - return x <= zero ? zero : (x < t ? x : t); - } -}; - -template -struct CudaRelu6GradFunctor : public BaseActivationFunctor { - T zero = static_cast(0.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - // dx = (out > 0 && out < t) ? dout : 0 - __device__ __forceinline__ T operator()(const T dout, const T out) const { - T t = static_cast(threshold); - return (out > zero && out < t) ? 
dout : zero; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - template struct CudaSoftsignFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); @@ -144,8 +107,8 @@ class ActivationCudaKernel for (auto& attr : attrs) { *attr.second = ctx.Attr(attr.first); } - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + dev_ctx, ins, &outs, functor); } }; @@ -158,8 +121,8 @@ class ActivationGradCudaKernel const framework::Tensor *x, *out, *d_out; framework::Tensor* d_x = nullptr; x = out = d_out = nullptr; - ExtractActivationGradTensor(ctx, &x, &out, &d_out, - &d_x); + ExtractActivationGradTensor( + ctx, &x, &out, &d_out, &d_x); d_x->mutable_data(ctx.GetPlace()); auto& dev_ctx = ctx.template device_context(); auto functor = Functor(); @@ -175,17 +138,17 @@ class ActivationGradCudaKernel static_cast(ActBwdOpFwdDeps::kDepOut)) { // Only need forward output Out ins.push_back(out); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + dev_ctx, ins, &outs, functor); } else if (static_cast(Functor::FwdDeps()) == static_cast(ActBwdOpFwdDeps::kDepX)) { // Only need forward input X ins.push_back(x); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + dev_ctx, ins, &outs, functor); } else { - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + dev_ctx, ins, &outs, functor); } } }; @@ -205,6 +168,7 @@ USE_PHI_FUNCTOR(CudaTanh) USE_PHI_FUNCTOR(CudaBRelu) USE_PHI_FUNCTOR(CudaLeakyRelu) USE_PHI_FUNCTOR(CudaThresholdedRelu) +USE_PHI_FUNCTOR(CudaRelu6) USE_PHI_FUNCTOR(CudaHardShrink) USE_PHI_FUNCTOR(CudaSoftShrink) USE_PHI_FUNCTOR(CudaTanhShrink) @@ -252,61 +216,64 @@ using CudaELUGradNegativeAlphaFunctor = namespace ops = paddle::operators; namespace plat = paddle::platform; -#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \ - grad_functor) \ - REGISTER_OP_CUDA_KERNEL( \ - act_type, ops::ActivationCudaKernel>, \ - ops::ActivationCudaKernel>, \ - ops::ActivationCudaKernel>, \ - ops::ActivationCudaKernel>); \ - REGISTER_OP_CUDA_KERNEL( \ - act_type##_grad, \ - ops::ActivationGradCudaKernel>, \ - ops::ActivationGradCudaKernel>, \ - ops::ActivationGradCudaKernel>, \ - ops::ActivationGradCudaKernel>, \ + ops::ActivationCudaKernel>, \ + ops::ActivationCudaKernel>, \ + ops::ActivationCudaKernel>); \ + REGISTER_OP_CUDA_KERNEL( \ + act_type##_grad, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>); -#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor, \ - grad_functor) \ - REGISTER_OP_CUDA_KERNEL( \ - act_type, ops::ActivationCudaKernel>, \ - ops::ActivationCudaKernel>, \ - ops::ActivationCudaKernel>, \ - ops::ActivationCudaKernel>, \ - ops::ActivationCudaKernel>, \ - ops::ActivationCudaKernel>); \ - REGISTER_OP_CUDA_KERNEL( \ - act_type##_grad, \ - ops::ActivationGradCudaKernel>, \ - ops::ActivationGradCudaKernel>, \ - ops::ActivationGradCudaKernel>, \ - ops::ActivationGradCudaKernel>, \ - ops::ActivationGradCudaKernel>, \ - ops::ActivationGradCudaKernel>, \ + ops::ActivationCudaKernel>, \ + ops::ActivationCudaKernel>, \ + ops::ActivationCudaKernel>, \ 
+ ops::ActivationCudaKernel>, \ + ops::ActivationCudaKernel>); \ + REGISTER_OP_CUDA_KERNEL( \ + act_type##_grad, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>); REGISTER_OP_CUDA_KERNEL( - relu6, ops::ActivationCudaKernel>, + relu6, + ops::ActivationCudaKernel>, ops::ActivationCudaKernel>, ops::ActivationCudaKernel>); REGISTER_OP_CUDA_KERNEL( - relu6_grad, ops::ActivationGradCudaKernel>, + relu6_grad, + ops::ActivationGradCudaKernel>, ops::ActivationGradCudaKernel>, ops::ActivationGradCudaKernel>); REGISTER_OP_KERNEL( - brelu_grad, KP, plat::XPUPlace, + brelu_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(ceil, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(ceil, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - ceil_grad, KP, plat::XPUPlace, + ceil_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); REGISTER_OP_KERNEL( - celu, KP, plat::XPUPlace, + celu, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - celu_grad, KP, plat::XPUPlace, + celu_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(elu, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(elu, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - elu_grad, KP, plat::XPUPlace, + elu_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(exp, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(exp, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - exp_grad, KP, plat::XPUPlace, + exp_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(floor, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(floor, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - floor_grad, KP, plat::XPUPlace, + floor_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); REGISTER_OP_KERNEL( - hard_shrink, KP, plat::XPUPlace, + hard_shrink, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - hard_shrink_grad, KP, plat::XPUPlace, + hard_shrink_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); REGISTER_OP_KERNEL( - hard_sigmoid, KP, plat::XPUPlace, + hard_sigmoid, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - hard_sigmoid_grad, KP, plat::XPUPlace, + hard_sigmoid_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(hard_swish, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(hard_swish, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - hard_swish_grad, KP, plat::XPUPlace, + hard_swish_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); REGISTER_OP_KERNEL( - leaky_relu, KP, plat::XPUPlace, + leaky_relu, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - leaky_relu_grad, KP, plat::XPUPlace, + leaky_relu_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(log, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(log, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - log_grad, KP, plat::XPUPlace, + log_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(log1p, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(log1p, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - log1p_grad, KP, plat::XPUPlace, + log1p_grad, + KP, + plat::XPUPlace, 
ops::ActivationGradCudaKernel>); REGISTER_OP_KERNEL( - logsigmoid, KP, plat::XPUPlace, + logsigmoid, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - logsigmoid_grad, KP, plat::XPUPlace, + logsigmoid_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); REGISTER_OP_KERNEL( - reciprocal, KP, plat::XPUPlace, + reciprocal, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - reciprocal_grad, KP, plat::XPUPlace, + reciprocal_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); REGISTER_OP_KERNEL( - relu, KP, plat::XPUPlace, + relu, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - relu_grad, KP, plat::XPUPlace, + relu_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(relu6, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(relu6, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - relu6_grad, KP, plat::XPUPlace, + relu6_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(sigmoid, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(sigmoid, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - sigmoid_grad, KP, plat::XPUPlace, + sigmoid_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(silu, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(silu, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - silu_grad, KP, plat::XPUPlace, + silu_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(soft_relu, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(soft_relu, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - soft_relu_grad, KP, plat::XPUPlace, + soft_relu_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(softplus, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(softplus, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - softplus_grad, KP, plat::XPUPlace, + softplus_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); REGISTER_OP_KERNEL( - softshrink, KP, plat::XPUPlace, + softshrink, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - softshrink_grad, KP, plat::XPUPlace, + softshrink_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(softsign, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(softsign, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - softsign_grad, KP, plat::XPUPlace, + softsign_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(sqrt, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(sqrt, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - sqrt_grad, KP, plat::XPUPlace, + sqrt_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(square, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(square, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - square_grad, KP, plat::XPUPlace, + square_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); -REGISTER_OP_KERNEL(swish, KP, plat::XPUPlace, +REGISTER_OP_KERNEL(swish, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - swish_grad, KP, plat::XPUPlace, + swish_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); REGISTER_OP_KERNEL( - thresholded_relu, KP, plat::XPUPlace, + thresholded_relu, + KP, + plat::XPUPlace, ops::ActivationCudaKernel>); REGISTER_OP_KERNEL( - thresholded_relu_grad, KP, plat::XPUPlace, + 
thresholded_relu_grad, + KP, + plat::XPUPlace, ops::ActivationGradCudaKernel>); diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h index 8a1eacd3709..5cc4357c937 100644 --- a/paddle/phi/kernels/activation_kernel.h +++ b/paddle/phi/kernels/activation_kernel.h @@ -75,6 +75,7 @@ DECLARE_ACTIVATION_KERNEL(Negative) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold) +DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha) diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index 894f959c131..c498fa48706 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -156,6 +156,9 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, ThresholdedReluGradFunctor, threshold); +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6, + Relu6GradFunctor, + threshold); DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, SoftShrinkGradFunctor, lambda); @@ -263,6 +266,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(brelu_grad, BReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad, ThresholdedReluGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel) diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index aca97340df4..8560576a425 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -95,6 +95,7 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, ThresholdedReluFunctor, threshold) +DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, Relu6Functor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda) @@ -147,6 +148,7 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel) +PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel) PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel) PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel) PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel) diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 9cfd9c10836..bd5e3dec3d6 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -1097,6 +1097,44 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +// relu6(x) = min(max(0, x), 6) +template +struct 
Relu6Functor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Out out) const { + out.device(d) = + x.cwiseMax(static_cast(0)).cwiseMin(static_cast(threshold)); + } +}; + +template +struct Relu6GradFunctor : public BaseActivationFunctor { + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = + dout * ((out > static_cast(0)) * (out < static_cast(threshold))) + .template cast(); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + // tanhshrink(x) = x - tanh(x) // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template @@ -2712,6 +2750,41 @@ struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaRelu6Functor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + // relu6(x) = min(max(0, x), 6) + __device__ __forceinline__ T operator()(const T x) const { + T t = static_cast(threshold); + return x <= zero ? zero : (x < t ? x : t); + } +}; + +template +struct CudaRelu6GradFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + // dx = (out > 0 && out < t) ? dout : 0 + __device__ __forceinline__ T operator()(const T dout, const T out) const { + T t = static_cast(threshold); + return (out > zero && out < t) ? 
dout : zero; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; template struct CudaLeakyReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index a127c9e9135..a53c2a05d83 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -223,6 +223,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, CudaCELUGradFunctor, alpha); +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6, + CudaRelu6GradFunctor, + threshold); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, CudaBReluGradFunctor, @@ -348,6 +351,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad, LeakyReluDoubleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad, ThresholdedReluGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel) diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index 55df4268e00..b7ff76f7446 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -111,6 +111,7 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, CudaThresholdedReluFunctor, threshold) +DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, CudaRelu6Functor, threshold) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, CudaHardShrinkFunctor, threshold) @@ -192,6 +193,7 @@ PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel) PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel) +PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel) diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc index ac0c3021a32..93717e71569 100644 --- a/paddle/phi/ops/compat/activation_sig.cc +++ b/paddle/phi/ops/compat/activation_sig.cc @@ -82,14 +82,15 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Softplus, "softplus", "beta" comma "threshold"); // NOLINT -DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", ); // NOLINT -DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Tanh, "tanh", ); // NOLINT -DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Sigmoid, "sigmoid", ); // NOLINT -DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Exp, "exp", ); // NOLINT -DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Expm1, "expm1", ); // NOLINT -DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Reciprocal, "reciprocal", ); // NOLINT -DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Sqrt, "sqrt", ); // NOLINT -DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Rsqrt, "rsqrt", ); // NOLINT +DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", ); // NOLINT +DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Tanh, "tanh", ); // NOLINT +DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Sigmoid, "sigmoid", ); // NOLINT +DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Exp, "exp", ); // NOLINT +DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Expm1, "expm1", ); // NOLINT +DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Reciprocal, "reciprocal", ); // NOLINT +DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Sqrt, 
"sqrt", ); // NOLINT +DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Rsqrt, "rsqrt", ); // NOLINT +DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu6, "relu6", "threshold"); // NOLINT DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(HardSigmoid, "hard_sigmoid", @@ -282,6 +283,7 @@ PD_REGISTER_ARG_MAPPING_FN(leaky_relu_grad_grad, phi::LeakyReluDoubleGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(thresholded_relu_grad, phi::ThresholdedReluGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(relu6_grad, phi::Relu6GradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(softshrink_grad, phi::SoftShrinkGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(hard_shrink_grad, diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index dd242920db3..4b817bb22eb 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -1737,6 +1737,7 @@ class TestRelu6(TestActivation): def setUp(self): self.op_type = "relu6" self.init_dtype() + self.python_api = paddle.nn.functional.relu6 np.random.seed(1024) x = np.random.uniform(-1, 10, [10, 12]).astype(self.dtype) @@ -1750,7 +1751,7 @@ class TestRelu6(TestActivation): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestRelu6API(unittest.TestCase): diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index aed8fbb0f58..7c2dff5247e 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -860,6 +860,8 @@ def relu6(x, name=None): out = F.relu6(x) # [0, 0.3, 6] """ threshold = 6.0 + if in_dygraph_mode(): + return _C_ops.final_state_relu6(x, threshold) if in_dynamic_mode(): return _C_ops.relu6(x, 'threshold', threshold) diff --git a/python/paddle/utils/code_gen/legacy_api.yaml b/python/paddle/utils/code_gen/legacy_api.yaml index f5c261dc7e3..8d20833c652 100644 --- a/python/paddle/utils/code_gen/legacy_api.yaml +++ b/python/paddle/utils/code_gen/legacy_api.yaml @@ -1713,6 +1713,16 @@ inplace : (x -> out) backward : relu_grad +- api : relu6 + args : (Tensor x, float threshold) + output : Tensor + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : relu6 + backward : relu6_grad + - api : reshape args : (Tensor x, IntArray shape) output : Tensor(out), Tensor(xshape) diff --git a/python/paddle/utils/code_gen/legacy_backward.yaml b/python/paddle/utils/code_gen/legacy_backward.yaml index c110d6f563b..16d58fde77f 100644 --- a/python/paddle/utils/code_gen/legacy_backward.yaml +++ b/python/paddle/utils/code_gen/legacy_backward.yaml @@ -1565,6 +1565,17 @@ kernel : func : prod_grad +- backward_api : relu6_grad + forward : relu6 (Tensor x, float threshold) -> Tensor(out) + args : (Tensor out, Tensor out_grad, float threshold) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] + kernel : + func : relu6_grad + inplace : (out_grad -> x_grad) + - backward_api : relu_double_grad forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x) args : (Tensor out, Tensor grad_x_grad) -- GitLab