From ae5445860efc7dda11f257d532c9ff1a1d9ca185 Mon Sep 17 00:00:00 2001 From: haosicheng <47998305+HarperCy@users.noreply.github.com> Date: Fri, 23 Dec 2022 14:22:40 +0800 Subject: [PATCH] square_grad support fp16 *test=kunlun (#48847) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 3 ++- paddle/phi/kernels/xpu/activation_grad_kernel.cc | 9 ++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 3a7a0f2fd6b..fef881ffd24 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -479,7 +479,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT32})}, {"sqrt", XPUKernelSet({phi::DataType::FLOAT32})}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"square_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"square_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"square", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"squeeze2_grad", diff --git a/paddle/phi/kernels/xpu/activation_grad_kernel.cc b/paddle/phi/kernels/xpu/activation_grad_kernel.cc index 4ab540a5705..3c3e16e0eb2 100644 --- a/paddle/phi/kernels/xpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_grad_kernel.cc @@ -640,6 +640,14 @@ PD_REGISTER_KERNEL(tanh_grad, phi::TanhGradKernel, float, phi::dtype::float16) {} + +PD_REGISTER_KERNEL(square_grad, + XPU, + ALL_LAYOUT, + phi::SquareGradKernel, + float, + phi::dtype::float16) {} + PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) @@ -652,5 +660,4 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(square_grad, SquareGradKernel) PD_REGISTER_KERNEL(pow_grad, XPU, ALL_LAYOUT, phi::PowGradKernel, float) {} -- GitLab