From 81d0a3cc453d8dcc95484d822e7fd3df68e6abac Mon Sep 17 00:00:00 2001 From: haosicheng <47998305+HarperCy@users.noreply.github.com> Date: Mon, 28 Nov 2022 10:17:00 +0800 Subject: [PATCH] add square fp16 *test=kunlun (#48095) --- paddle/fluid/platform/device/xpu/xpu2_op_list.h | 4 +++- paddle/phi/kernels/xpu/activation_kernel.cc | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index f8b15d4d4e..b1838a0f71 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -568,7 +568,9 @@ XPUOpMap& get_kl2_ops() { {"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"square_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"square", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"squeeze2_grad", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index 51f74bd347..2425f304a3 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -479,6 +479,9 @@ PD_REGISTER_KERNEL( PD_REGISTER_KERNEL( tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL( + square, XPU, ALL_LAYOUT, phi::SquareKernel, float, phi::dtype::float16) {} + PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) @@ -492,4 +495,3 @@ PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) -PD_REGISTER_ACTIVATION_KERNEL(square, SquareKernel) -- GitLab