From 56d46ccc420c52a54d6ee2faef5acb00c6d599fd Mon Sep 17 00:00:00 2001
From: zhangyuqin1998 <75946871+zhangyuqin1998@users.noreply.github.com>
Date: Wed, 19 Jul 2023 09:37:38 +0800
Subject: [PATCH] delete relu6_raw (#55383)

* delete relu6_raw

* fix codestyle

* Update test_mkldnn_matmul_activation_fuse_pass.py

* fix

* Update backward.yaml

* Update ops.yaml

* Update backward.yaml
---
 paddle/phi/api/yaml/backward.yaml             | 11 ++++
 paddle/phi/api/yaml/legacy_backward.yaml      | 11 ----
 paddle/phi/api/yaml/legacy_ops.yaml           | 10 ----
 paddle/phi/api/yaml/op_compat.yaml            |  2 +-
 paddle/phi/api/yaml/ops.yaml                  | 10 ++++
 paddle/phi/api/yaml/static_backward.yaml      | 11 ----
 paddle/phi/api/yaml/static_ops.yaml           | 10 ----
 paddle/phi/kernels/activation_kernel.cc       | 54 ------
 paddle/phi/kernels/activation_kernel.h        |  1 -
 paddle/phi/kernels/cpu/activation_kernel.cc   | 14 ++++-
 paddle/phi/kernels/gpu/activation_kernel.cu   | 14 ++++-
 .../phi/kernels/onednn/activation_kernel.cc   | 11 ++--
 paddle/phi/kernels/sparse/cpu/unary_kernel.cc |  1 -
 paddle/phi/kernels/sparse/gpu/unary_kernel.cu |  1 -
 .../kernels/sparse/impl/unary_kernel_impl.h   | 16 +-----
 paddle/phi/kernels/sparse/unary_kernel.h      |  1 -
 paddle/phi/kernels/xpu/activation_kernel.cc   | 22 ++++----
 .../test_conv_act_onednn_fuse_pass.py         |  2 +-
 ...test_mkldnn_matmul_activation_fuse_pass.py |  2 +-
 ...ul_elementwise_add_activation_fuse_pass.py |  2 +-
 ...t_mkldnn_matmul_v2_activation_fuse_pass.py |  2 +-
 ...onednn_conv_concat_activation_fuse_pass.py |  2 +-
 ...nn_elementwise_add_activation_fuse_pass.py |  2 +-
 ...st_onednn_softplus_activation_fuse_pass.py |  2 +-
 24 files changed, 71 insertions(+), 143 deletions(-)
 delete mode 100644 paddle/phi/kernels/activation_kernel.cc

diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index b6347e35fcc..2ea1660e746 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1740,6 +1740,17 @@
     func : reciprocal_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : relu6_grad
+  forward : relu6 (Tensor x) -> Tensor(out)
+  args : (Tensor out, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [out]
+  kernel :
+    func : relu6_grad
+  inplace : (out_grad -> x_grad)
+
 - backward_op : relu_double_grad
   forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
   args : (Tensor out, Tensor grad_x_grad)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 6651b1bcabd..246c14fc2d8 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -485,17 +485,6 @@
     func : prod_grad
   composite: prod_grad(x, out, out_grad, dims, keep_dim, reduce_all, x_grad)
 
-- backward_op : relu6_grad
-  forward : relu6 (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : UnchangedInferMeta
-    param : [out]
-  kernel :
-    func : relu6_grad
-  inplace : (out_grad -> x_grad)
-
 - backward_op : repeat_interleave_grad
   forward : repeat_interleave(Tensor x, int repeats, int axis) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, int repeats, int axis)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 0614007fb99..110928d0605 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -720,16 +720,6 @@
     data_type : dtype
     backend : place
 
-- op : relu6
-  args : (Tensor x)
-  output : Tensor
-  infer_meta :
-    func : UnchangedInferMeta
-    param : [x]
-  kernel :
-    func : relu6
-  backward : relu6_grad
-
 - op : remainder
   args : (Tensor x, Tensor y)
   output : Tensor (out)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 590808a6bd8..2026e58e8ec 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -2251,7 +2251,7 @@
   outputs :
     out : Out
   extra :
-    attrs : [bool use_mkldnn = false]
+    attrs : [bool use_mkldnn = false, float threshold = 6.0]
 
 - op : remainder (elementwise_mod)
   inputs :
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 4e4d4803b21..4f80635d0e8 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1996,6 +1996,16 @@
   inplace : (x -> out)
   backward : relu_grad
 
+- op : relu6
+  args : (Tensor x)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : relu6
+  backward : relu6_grad
+
 - op : renorm
   args : (Tensor x, float p, int axis, float max_norm)
   output : Tensor
diff --git a/paddle/phi/api/yaml/static_backward.yaml b/paddle/phi/api/yaml/static_backward.yaml
index ebbca84b67d..526a7195a5b 100755
--- a/paddle/phi/api/yaml/static_backward.yaml
+++ b/paddle/phi/api/yaml/static_backward.yaml
@@ -278,17 +278,6 @@
     func : prod_grad
   composite: prod_grad(x, out, out_grad, dims, keep_dim, reduce_all, x_grad)
 
-- backward_op : relu6_grad
-  forward : relu6 (Tensor x, float threshold = 6.0f) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : UnchangedInferMeta
-    param : [out]
-  kernel :
-    func : relu6_grad
-  inplace : (out_grad -> x_grad)
-
 - backward_op : rnn_grad
   forward : rnn (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false) -> Tensor(out), Tensor(dropout_state_out), Tensor[](state), Tensor(reserve)
   args : (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor out, Tensor dropout_state_out, Tensor reserve, Tensor out_grad, Tensor[] state_grad, float dropout_prob, bool is_bidirec, int input_size, int hidden_size, int num_layers, str mode, int seed, bool is_test)
diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml
index 069fdd7289f..216fca178fd 100755
--- a/paddle/phi/api/yaml/static_ops.yaml
+++ b/paddle/phi/api/yaml/static_ops.yaml
@@ -544,16 +544,6 @@
     func : reduce_scatter
     param: [x, nranks]
 
-- op : relu6
-  args : (Tensor x, float threshold = 6.0f)
-  output : Tensor
-  infer_meta :
-    func : UnchangedInferMeta
-    param : [x]
-  kernel :
-    func : relu6_raw
-  backward : relu6_grad
-
 - op : remainder
   args : (Tensor x, Tensor y, int axis = -1)
   output : Tensor (out)
diff --git a/paddle/phi/kernels/activation_kernel.cc b/paddle/phi/kernels/activation_kernel.cc
deleted file mode 100644
index f157c5e054b..00000000000
--- a/paddle/phi/kernels/activation_kernel.cc
+++ /dev/null
@@ -1,54 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/activation_kernel.h"
-
-#include "paddle/phi/backends/all_context.h"
-#include "paddle/phi/core/kernel_registry.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void Relu6Kernel(const Context& dev_ctx,
-                 const DenseTensor& x,
-                 DenseTensor* out) {
-  Relu6RawKernel<T, Context>(dev_ctx, x, 6, out);
-}
-
-}  // namespace phi
-using complex64 = ::phi::dtype::complex<float>;
-using complex128 = ::phi::dtype::complex<double>;
-
-PD_REGISTER_KERNEL(relu6, CPU, ALL_LAYOUT, phi::Relu6Kernel, float, double) {}
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(relu6,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::Relu6Kernel,
-                   float,
-                   double,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
-#endif
-
-#if defined PADDLE_WITH_XPU
-PD_REGISTER_KERNEL(
-    relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float, phi::dtype::float16) {}
-#endif
-
-#ifdef PADDLE_WITH_MKLDNN
-PD_REGISTER_KERNEL(
-    relu6, OneDNN, ONEDNN, phi::Relu6Kernel, float, phi::dtype::bfloat16) {}
-#endif
diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h
index 3896324be79..70c0187e688 100644
--- a/paddle/phi/kernels/activation_kernel.h
+++ b/paddle/phi/kernels/activation_kernel.h
@@ -75,7 +75,6 @@ DECLARE_ACTIVATION_KERNEL(Negative)
 
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
-DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6Raw, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 046cee58578..947ab5da81a 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -109,7 +109,6 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
                                      ThresholdedReluFunctor,
                                      threshold)
-DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu6Raw, Relu6Functor, threshold)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
@@ -150,6 +149,17 @@ void SwishKernel(const Context& dev_ctx,
   ActivationImpl<T, Context, funcs::SwishFunctor<T>>(
       dev_ctx, x, out, functor);
 }
+
+template <typename T, typename Context>
+void Relu6Kernel(const Context& dev_ctx,
+                 const DenseTensor& x,
+                 DenseTensor* out) {
+  funcs::Relu6Functor<T> functor;
+  auto attrs = functor.GetAttrs();
+  *(attrs[0].second) = 6.0;
+  ActivationImpl<T, Context, funcs::Relu6Functor<T>>(
+      dev_ctx, x, out, functor);
+}
 }  // namespace phi
 
 PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
@@ -171,7 +181,6 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
-PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
 PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel)
 PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
@@ -212,6 +221,7 @@ PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
+PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
 
 PD_REGISTER_KERNEL(log,
                    CPU,
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 83e130f0a71..330182286d5 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -126,7 +126,6 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps)
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
                                      CudaThresholdedReluFunctor,
                                      threshold)
-DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu6Raw, CudaRelu6Functor, threshold)
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
                                      CudaHardShrinkFunctor,
                                      threshold)
@@ -176,6 +175,17 @@ void SwishKernel(const Context& dev_ctx,
   ActivationGPUImpl<T, Context, funcs::CudaSwishFunctor<T>>(
       dev_ctx, x, out, functor);
 }
+
+template <typename T, typename Context>
+void Relu6Kernel(const Context& dev_ctx,
+                 const DenseTensor& x,
+                 DenseTensor* out) {
+  funcs::CudaRelu6Functor<T> functor;
+  auto attrs = functor.GetAttrs();
+  *(attrs[0].second) = 6.0;
+  ActivationGPUImpl<T, Context, funcs::CudaRelu6Functor<T>>(
+      dev_ctx, x, out, functor);
+}
 }  // namespace phi
 
 #ifdef PADDLE_WITH_HIP
@@ -221,7 +231,7 @@ PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
-PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
+PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel)
diff --git a/paddle/phi/kernels/onednn/activation_kernel.cc b/paddle/phi/kernels/onednn/activation_kernel.cc
index 63ee49626a3..a4757eab71c 100644
--- a/paddle/phi/kernels/onednn/activation_kernel.cc
+++ b/paddle/phi/kernels/onednn/activation_kernel.cc
@@ -178,12 +178,11 @@ void GeluKernel(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void Relu6RawKernel(const Context& dev_ctx,
-                    const DenseTensor& x,
-                    float threshold,
-                    DenseTensor* out) {
+void Relu6Kernel(const Context& dev_ctx,
+                 const DenseTensor& x,
+                 DenseTensor* out) {
   Relu6OneDNNFunctor<T> functor;
-  functor(dev_ctx, x, 0, threshold, out);
+  functor(dev_ctx, x, 0, 6.0, out);
 }
 
 template <typename T, typename Context>
@@ -210,7 +209,7 @@ PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(relu, ReluKernel)
-PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
+PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
diff --git a/paddle/phi/kernels/sparse/cpu/unary_kernel.cc b/paddle/phi/kernels/sparse/cpu/unary_kernel.cc
index d36439549cf..53956174044 100644
--- a/paddle/phi/kernels/sparse/cpu/unary_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/unary_kernel.cc
@@ -95,7 +95,6 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(pow, Pow)
 PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(scale, Scale)
 PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(expm1, Expm1)
 PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6, Relu6)
-PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6_raw, Relu6Raw)
 PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(leaky_relu, LeakyRelu)
 
 PD_REGISTER_KERNEL(divide_scalar_coo,
diff --git a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
index ba9c3dbf6e3..3b6e84664f9 100644
--- a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
@@ -99,7 +99,6 @@ PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(abs, Abs)
 PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(pow, Pow)
 PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(scale, Scale)
 PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(expm1, Expm1)
-PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6_raw, Relu6Raw)
 PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6, Relu6)
 PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(leaky_relu, LeakyRelu)
 
diff --git a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
index 06d7a433364..5c930dd48b8 100644
--- a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
+++ b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
@@ -89,24 +89,10 @@ DEFINE_SPARSE_UNARY_KERNEL(Log1p)
 DEFINE_SPARSE_UNARY_KERNEL(Relu)
 DEFINE_SPARSE_UNARY_KERNEL(Abs)
 DEFINE_SPARSE_UNARY_KERNEL(Expm1)
+DEFINE_SPARSE_UNARY_KERNEL(Relu6)
 DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
-DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6Raw, threshold)
 DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
 
-template <typename T, typename Context>
-void Relu6CooKernel(const Context& dev_ctx,
-                    const SparseCooTensor& x,
-                    SparseCooTensor* out) {
-  Relu6RawCooKernel<T, Context>(dev_ctx, x, 6, out);
-}
-
-template <typename T, typename Context>
-void Relu6CsrKernel(const Context& dev_ctx,
-                    const SparseCsrTensor& x,
-                    SparseCsrTensor* out) {
-  Relu6RawCsrKernel<T, Context>(dev_ctx, x, 6, out);
-}
-
 template <typename T, typename Context>
 void ScaleCooKernel(const Context& dev_ctx,
                     const SparseCooTensor& x,
diff --git a/paddle/phi/kernels/sparse/unary_kernel.h b/paddle/phi/kernels/sparse/unary_kernel.h
index 0faf0b045ee..24bf4f131f6 100644
--- a/paddle/phi/kernels/sparse/unary_kernel.h
+++ b/paddle/phi/kernels/sparse/unary_kernel.h
@@ -60,7 +60,6 @@ DECLARE_SPARSE_UNARY_KERNEL(Log1p)
 DECLARE_SPARSE_UNARY_KERNEL(Abs)
 DECLARE_SPARSE_UNARY_KERNEL(Expm1)
 DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
-DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6Raw, threshold)
 DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
 
 template <typename T, typename Context>
diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc
index 4edbd71a9fc..bd19fc0d9c6 100644
--- a/paddle/phi/kernels/xpu/activation_kernel.cc
+++ b/paddle/phi/kernels/xpu/activation_kernel.cc
@@ -415,6 +415,16 @@ void SwishKernel(const Context& dev_ctx,
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish");
 }
 
+template <typename T, typename Context>
+void Relu6Kernel(const Context& dev_ctx,
+                 const DenseTensor& x,
+                 DenseTensor* out) {
+  XPURelu6Functor<T> functor;
+  auto attrs = functor.GetAttrs();
+  *(attrs[0].second) = 6.0;
+  ActivationXPUImpl<T, Context, XPURelu6Functor<T>>(dev_ctx, x, out, functor);
+}
+
 template <typename T>
 struct XPUSoftplusFunctor : public funcs::BaseActivationFunctor<T> {
   using XPUType = typename XPUTypeTrait<T>::Type;
@@ -504,10 +514,6 @@ DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, XPUMishFunctor, threshold)
 DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu,
                                             XPULeakyReluFunctor,
                                             alpha)
-DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6Raw,
-                                            XPURelu6Functor,
-                                            threshold)
-
 DEFINE_XPU_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus,
                                             XPUSoftplusFunctor,
                                             beta,
@@ -567,12 +573,8 @@
 PD_REGISTER_KERNEL(
     log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {}
 
-PD_REGISTER_KERNEL(relu6_raw,
-                   XPU,
-                   ALL_LAYOUT,
-                   phi::Relu6RawKernel,
-                   float,
-                   phi::dtype::float16) {}
+PD_REGISTER_KERNEL(
+    relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float, phi::dtype::float16) {}
 
 #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
   PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
diff --git a/test/ir/inference/test_conv_act_onednn_fuse_pass.py b/test/ir/inference/test_conv_act_onednn_fuse_pass.py
index faa07dde674..71608f6d603 100755
--- a/test/ir/inference/test_conv_act_onednn_fuse_pass.py
+++ b/test/ir/inference/test_conv_act_onednn_fuse_pass.py
@@ -166,7 +166,7 @@ class TestConvActOneDNNFusePass(PassAutoScanTest):
                 'relu6',
                 inputs={'X': ['conv2d_out']},
                 outputs={'Out': ['relu_out']},
-                threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
+                threshold=6.0,
             )
         elif act_type == 'leaky_relu':
             act_op = OpConfig(
diff --git a/test/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py b/test/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py
index 85533734a1c..95b4b0613bd 100644
--- a/test/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py
+++ b/test/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py
@@ -86,7 +86,7 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
                 activation_type,
                 inputs={"X": ["matmul_output"]},
                 outputs={"Out": ["activation_output"]},
-                threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
+                threshold=6,
             )
         elif activation_type == "leaky_relu":
             activation_op = OpConfig(
diff --git a/test/ir/inference/test_mkldnn_matmul_elementwise_add_activation_fuse_pass.py b/test/ir/inference/test_mkldnn_matmul_elementwise_add_activation_fuse_pass.py
index 19592b91acf..ef560edbd18 100644
--- a/test/ir/inference/test_mkldnn_matmul_elementwise_add_activation_fuse_pass.py
+++ b/test/ir/inference/test_mkldnn_matmul_elementwise_add_activation_fuse_pass.py
@@ -81,7 +81,7 @@ class TestMatmulElementwiseAddActivationMkldnnFusePass(PassAutoScanTest):
                 activation_type,
                 inputs={"X": ["elementwise_add_output"]},
                 outputs={"Out": ["activation_output"]},
-                threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
+                threshold=6.0,
             )
         elif activation_type == "leaky_relu":
             activation_op = OpConfig(
diff --git a/test/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py b/test/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py
index 57403760bd9..cf20cf43ec3 100644
--- a/test/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py
+++ b/test/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py
@@ -90,7 +90,7 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
                 activation_type,
                 inputs={'X': ['matmul_output']},
                 outputs={'Out': ['activation_output']},
-                threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
+                threshold=6.0,
             )
         elif activation_type == "leaky_relu":
             activation_op = OpConfig(
diff --git a/test/ir/inference/test_onednn_conv_concat_activation_fuse_pass.py b/test/ir/inference/test_onednn_conv_concat_activation_fuse_pass.py
index ca8648d9a34..1a71841d22c 100644
--- a/test/ir/inference/test_onednn_conv_concat_activation_fuse_pass.py
+++ b/test/ir/inference/test_onednn_conv_concat_activation_fuse_pass.py
@@ -99,7 +99,7 @@ class TestOneDNNConvConcatActivationFusePass(PassAutoScanTest):
                 activation_type,
                 inputs={'X': ['concat_output']},
                 outputs={'Out': ['activation_output']},
-                threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
+                threshold=6.0,
             )
         elif activation_type == 'leaky_relu':
             activation_op = OpConfig(
diff --git a/test/ir/inference/test_onednn_elementwise_add_activation_fuse_pass.py b/test/ir/inference/test_onednn_elementwise_add_activation_fuse_pass.py
index 9047148e8b4..3d396968a76 100644
--- a/test/ir/inference/test_onednn_elementwise_add_activation_fuse_pass.py
+++ b/test/ir/inference/test_onednn_elementwise_add_activation_fuse_pass.py
@@ -62,7 +62,7 @@ class TestElementwiseAddActivationOneDNNFusePass(PassAutoScanTest):
                 activation_type,
                 inputs={'X': ['eltwise_output']},
                 outputs={'Out': ['activation_output']},
-                threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
+                threshold=6.0,
             )
         elif activation_type == "leaky_relu":
             activation_op = OpConfig(
diff --git a/test/ir/inference/test_onednn_softplus_activation_fuse_pass.py b/test/ir/inference/test_onednn_softplus_activation_fuse_pass.py
index 2f15d8a43c6..4a8e8604480 100644
--- a/test/ir/inference/test_onednn_softplus_activation_fuse_pass.py
+++ b/test/ir/inference/test_onednn_softplus_activation_fuse_pass.py
@@ -85,7 +85,7 @@ class TestSoftplusActivationOneDNNFusePass(PassAutoScanTest):
                 activation_type,
                 inputs={'X': ['softplus_out']},
                 outputs={'Out': ['activation_output']},
-                threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
+                threshold=6.0,
             )
         elif activation_type == 'swish':
             activation_op = OpConfig(
-- 
GitLab