From 3e636ec97e37e39847641acaa86c3b03846dc0d0 Mon Sep 17 00:00:00 2001
From: zhangyuqin1998 <75946871+zhangyuqin1998@users.noreply.github.com>
Date: Wed, 15 Mar 2023 10:32:44 +0800
Subject: [PATCH] Delete hardswish_raw op (#51634)

* Delete hardswish_raw op

* fix ut
---
 paddle/phi/api/yaml/legacy_backward.yaml      |  2 +-
 paddle/phi/kernels/activation_grad_kernel.h   |  3 ---
 paddle/phi/kernels/activation_kernel.cc       | 25 -------------------
 paddle/phi/kernels/activation_kernel.h        |  8 ------
 .../phi/kernels/cpu/activation_grad_kernel.cc |  6 ++---
 paddle/phi/kernels/cpu/activation_kernel.cc   | 14 +++++------
 .../phi/kernels/gpu/activation_grad_kernel.cu |  6 ++---
 paddle/phi/kernels/gpu/activation_kernel.cu   | 14 +++++------
 .../kernels/onednn/activation_grad_kernel.cc  |  3 ---
 .../phi/kernels/onednn/activation_kernel.cc   | 13 ++++------
 .../phi/kernels/xpu/activation_grad_kernel.cc |  6 ++---
 paddle/phi/kernels/xpu/activation_kernel.cc   | 14 +++++------
 paddle/phi/ops/compat/activation_sig.cc       | 12 ++++-----
 .../inference/test_trt_convert_hard_swish.py  |  6 ++---
 14 files changed, 45 insertions(+), 87 deletions(-)

diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 46c7a81db0..53cf1945f2 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -532,7 +532,7 @@
 - backward_op : hardswish_grad
   forward : hardswish (Tensor x) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad, float threshold = 6.0, float scale = 6.0, float offset = 3.0)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h
index b65a2304ca..b322ed5e02 100644
--- a/paddle/phi/kernels/activation_grad_kernel.h
+++ b/paddle/phi/kernels/activation_grad_kernel.h
@@ -225,9 +225,6 @@
 template <typename T, typename Context>
 void HardSwishGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& dout,
-                         float threshold,
-                         float scale,
-                         float offset,
                          DenseTensor* dx);
 
 template <typename T, typename Context>
diff --git a/paddle/phi/kernels/activation_kernel.cc b/paddle/phi/kernels/activation_kernel.cc
index 3de8a867fd..ef6135d25c 100644
--- a/paddle/phi/kernels/activation_kernel.cc
+++ b/paddle/phi/kernels/activation_kernel.cc
@@ -19,13 +19,6 @@
 
 namespace phi {
 
-template <typename T, typename Context>
-void HardSwishKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     DenseTensor* out) {
-  HardSwishRawKernel<T>(dev_ctx, x, 6, 6, 3, out);
-}
-
 template <typename T, typename Context>
 void Relu6Kernel(const Context& dev_ctx,
                  const DenseTensor& x,
@@ -44,21 +37,10 @@ void SwishKernel(const Context& dev_ctx,
 using complex64 = ::phi::dtype::complex<float>;
 using complex128 = ::phi::dtype::complex<double>;
 
-PD_REGISTER_KERNEL(
-    hardswish, CPU, ALL_LAYOUT, phi::HardSwishKernel, float, double) {}
 PD_REGISTER_KERNEL(relu6, CPU, ALL_LAYOUT, phi::Relu6Kernel, float, double) {}
 PD_REGISTER_KERNEL(swish, CPU, ALL_LAYOUT, phi::SwishKernel, float, double) {}
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(hardswish,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::HardSwishKernel,
-                   float,
-                   double,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
-
 PD_REGISTER_KERNEL(relu6,
                    GPU,
                    ALL_LAYOUT,
@@ -80,18 +62,11 @@ PD_REGISTER_KERNEL(swish,
 #endif
 
 #if defined PADDLE_WITH_XPU
-PD_REGISTER_KERNEL(hardswish, XPU, ALL_LAYOUT, phi::HardSwishKernel, float) {}
 PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {}
 PD_REGISTER_KERNEL(swish, XPU, ALL_LAYOUT, phi::SwishKernel, float) {}
 #endif
 
 #ifdef PADDLE_WITH_MKLDNN
-PD_REGISTER_KERNEL(hardswish,
-                   OneDNN,
-                   ONEDNN,
-                   phi::HardSwishKernel,
-                   float,
-                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(
     relu6, OneDNN, ONEDNN, phi::Relu6Kernel, float, phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(
diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h
index 9ea8423253..0d7ec8e8b7 100644
--- a/paddle/phi/kernels/activation_kernel.h
+++ b/paddle/phi/kernels/activation_kernel.h
@@ -90,14 +90,6 @@ DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
 DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, beta, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset)
 
-template <typename T, typename Context>
-void HardSwishRawKernel(const Context& dev_ctx,
-                        const DenseTensor& x,
-                        float threshold,
-                        float scale,
-                        float offset,
-                        DenseTensor* out);
-
 template <typename T, typename Context>
 void HardSwishKernel(const Context& dev_ctx,
                      const DenseTensor& x,
diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index 1f3e8b4cc7..e15ae5bb89 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -226,11 +226,11 @@ template <typename T, typename Context>
 void HardSwishGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& dout,
-                         float threshold,
-                         float scale,
-                         float offset,
                          DenseTensor* dx) {
   funcs::HardSwishGradFunctor<T> functor;
+  float threshold = 6;
+  float scale = 6;
+  float offset = 3;
   auto attrs = functor.GetAttrs();
   *(attrs[0].second) = threshold;
   *(attrs[1].second) = scale;
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 70b011eafe..355dc3547f 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -113,13 +113,13 @@ DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
                                      offset)
 
 template <typename T, typename Context>
-void HardSwishRawKernel(const Context& dev_ctx,
-                        const DenseTensor& x,
-                        float threshold,
-                        float scale,
-                        float offset,
-                        DenseTensor* out) {
+void HardSwishKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     DenseTensor* out) {
   funcs::HardSwishFunctor<T> functor;
+  float threshold = 6;
+  float scale = 6;
+  float offset = 3;
   auto attrs = functor.GetAttrs();
   *(attrs[0].second) = threshold;
   *(attrs[1].second) = scale;
@@ -183,7 +183,7 @@ PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index fc7bf8b1cc..617fbd45f0 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -274,11 +274,11 @@ template <typename T, typename Context>
 void HardSwishGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& dout,
-                         float threshold,
-                         float scale,
-                         float offset,
                          DenseTensor* dx) {
   funcs::CudaHardSwishGradFunctor<T> functor;
+  float threshold = 6;
+  float scale = 6;
+  float offset = 3;
   auto attrs = functor.GetAttrs();
   *(attrs[0].second) = threshold;
   *(attrs[1].second) = scale;
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 0b396b17f5..c60a937255 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -138,13 +138,13 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
 DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Selu, CudaSeluFunctor, scale, alpha)
 
 template <typename T, typename Context>
-void HardSwishRawKernel(const Context& dev_ctx,
-                        const DenseTensor& x,
-                        float threshold,
-                        float scale,
-                        float offset,
-                        DenseTensor* out) {
+void HardSwishKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     DenseTensor* out) {
   funcs::CudaHardSwishFunctor<T> functor;
+  float threshold = 6;
+  float scale = 6;
+  float offset = 3;
   auto attrs = functor.GetAttrs();
   *(attrs[0].second) = threshold;
   *(attrs[1].second) = scale;
@@ -257,7 +257,7 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
 PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
 PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
diff --git a/paddle/phi/kernels/onednn/activation_grad_kernel.cc b/paddle/phi/kernels/onednn/activation_grad_kernel.cc
index 489f53da76..6355908c25 100644
--- a/paddle/phi/kernels/onednn/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/onednn/activation_grad_kernel.cc
@@ -238,9 +238,6 @@ template <typename T, typename Context>
 void HardSwishGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& dout,
-                         float threshold,
-                         float scale,
-                         float offset,
                          DenseTensor* dx) {
   HardSwishOneDNNGradFunctor<T> functor;
   functor(dev_ctx, x, dout, 0, 0, dx);
diff --git a/paddle/phi/kernels/onednn/activation_kernel.cc b/paddle/phi/kernels/onednn/activation_kernel.cc
index 0eb9b4acdc..fda32f7617 100644
--- a/paddle/phi/kernels/onednn/activation_kernel.cc
+++ b/paddle/phi/kernels/onednn/activation_kernel.cc
@@ -157,14 +157,11 @@ DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold)
 DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(SwishRaw, SwishOneDNNFunctor, beta)
 
 template <typename T, typename Context>
-void HardSwishRawKernel(const Context& dev_ctx,
-                        const DenseTensor& x,
-                        float threshold,
-                        float scale,
-                        float offset,
-                        DenseTensor* out) {
+void HardSwishKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     DenseTensor* out) {
   HardSwishOneDNNFunctor<T> functor;
-  functor(dev_ctx, x, threshold, 0, out);
+  functor(dev_ctx, x, 6, 0, out);
 }
 
 template <typename T, typename Context>
@@ -202,7 +199,7 @@ PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel)
 PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)
 PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(relu, ReluKernel)
diff --git a/paddle/phi/kernels/xpu/activation_grad_kernel.cc b/paddle/phi/kernels/xpu/activation_grad_kernel.cc
index 8b2946d54c..df9674a16d 100644
--- a/paddle/phi/kernels/xpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/activation_grad_kernel.cc
@@ -603,11 +603,11 @@ template <typename T, typename Context>
 void HardSwishGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& dout,
-                         float threshold,
-                         float scale,
-                         float offset,
                          DenseTensor* dx) {
   XPUHardSwishGradFunctor<T> functor;
+  float threshold = 6;
+  float scale = 6;
+  float offset = 3;
   auto attrs = functor.GetAttrs();
   *(attrs[0].second) = threshold;
   *(attrs[1].second) = scale;
diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc
index 4fc2e56532..490c56d131 100644
--- a/paddle/phi/kernels/xpu/activation_kernel.cc
+++ b/paddle/phi/kernels/xpu/activation_kernel.cc
@@ -513,13 +513,13 @@ DEFINE_XPU_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
                                             offset)
 
 template <typename T, typename Context>
-void HardSwishRawKernel(const Context& dev_ctx,
-                        const DenseTensor& x,
-                        float threshold,
-                        float scale,
-                        float offset,
-                        DenseTensor* out) {
+void HardSwishKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     DenseTensor* out) {
   XPUHardSwishFunctor<T> functor;
+  float threshold = 6;
+  float scale = 6;
+  float offset = 3;
   auto attrs = functor.GetAttrs();
   *(attrs[0].second) = threshold;
   *(attrs[1].second) = scale;
@@ -551,7 +551,7 @@ PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)  // no grad
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc
index 75bf5b11f7..5106c63a9e 100644
--- a/paddle/phi/ops/compat/activation_sig.cc
+++ b/paddle/phi/ops/compat/activation_sig.cc
@@ -41,10 +41,6 @@ namespace phi {
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hardtanh", "t_min" comma "t_max");
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold");
-DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardSwish,
-                               "hardswish",
-                               "threshold" comma "scale" comma
-                               "offset");  // NOLINT
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Swish, "swish", "beta");  // NOLINT
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(STanh,
@@ -53,9 +49,13 @@
 DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu6, "relu6", "threshold");  // NOLINT
 
+KernelSignature HardSwishGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("hardswish_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
+}
+
 KernelSignature HardSwishOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "hardswish_raw", {"X"}, {"threshold", "scale", "offset"}, {"Out"});
+  return KernelSignature("hardswish", {"X"}, {}, {"Out"});
 }
 
 KernelSignature SwishOpArgumentMapping(const ArgumentMappingContext& ctx) {
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_hard_swish.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_hard_swish.py
index 235925e832..3c3a98ee1e 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_hard_swish.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_hard_swish.py
@@ -40,9 +40,9 @@ class TrtConvertHardSwishTest(TrtLayerAutoScanTest):
         def generate_input1(attrs: List[Dict[str, Any]]):
             return np.ones([1, 3, 32, 32]).astype(np.float32)
 
-        for threshold in [6.0, 7.0, 100.0, 0.0, -1.0]:
-            for scale in [5.0, 7.0, -1.0, 0.0, 100.0]:
-                for offset in [3.0, 5.0, -1.0, 0.0, 100.0]:
+        for threshold in [6.0]:
+            for scale in [6.0]:
+                for offset in [3.0]:
                     dics = [
                         {
                             "threshold": threshold,
-- 
GitLab
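
For readers tracing the behavior change: every kernel in this patch now fixes threshold = 6, scale = 6, offset = 3 instead of reading them as op attributes. A minimal NumPy sketch of the hardswish computation these functors implement follows; the standalone hard_swish helper and the self-check are illustrative, not part of the patch.

import numpy as np

def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0):
    # hardswish(x) = x * min(max(x + offset, 0), threshold) / scale.
    # The defaults are the values this patch hardcodes into every kernel;
    # with them the formula reduces to the familiar x * relu6(x + 3) / 6.
    return x * np.clip(x + offset, 0.0, threshold) / scale

# Same all-ones input the TensorRT test above generates;
# hard_swish(1) = 1 * min(max(1 + 3, 0), 6) / 6 = 2/3.
x = np.ones([1, 3, 32, 32]).astype(np.float32)
assert np.allclose(hard_swish(x), np.full_like(x, 2.0 / 3.0))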