diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 3a7a0f2fd6b5636f10f6d72be58ddcd29e850b4c..fef881ffd249d398664e507dc4a390bff2c20c77 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -479,7 +479,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT32})}, {"sqrt", XPUKernelSet({phi::DataType::FLOAT32})}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"square_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"square_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"square", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"squeeze2_grad", diff --git a/paddle/phi/kernels/xpu/activation_grad_kernel.cc b/paddle/phi/kernels/xpu/activation_grad_kernel.cc index 4ab540a570577a1edb31332594f50fb266cf8b41..3c3e16e0eb2e072ad8c8903842965b567265b8ac 100644 --- a/paddle/phi/kernels/xpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_grad_kernel.cc @@ -640,6 +640,14 @@ PD_REGISTER_KERNEL(tanh_grad, phi::TanhGradKernel, float, phi::dtype::float16) {} + +PD_REGISTER_KERNEL(square_grad, + XPU, + ALL_LAYOUT, + phi::SquareGradKernel, + float, + phi::dtype::float16) {} + PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) @@ -652,5 +660,4 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(square_grad, SquareGradKernel) PD_REGISTER_KERNEL(pow_grad, XPU, ALL_LAYOUT, phi::PowGradKernel, float) {}