未验证 提交 ae544586 编写于 作者: H haosicheng 提交者: GitHub

square_grad support fp16 *test=kunlun (#48847)

上级 c088f9ec
...@@ -479,7 +479,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -479,7 +479,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32})}, phi::DataType::INT32})},
{"sqrt", XPUKernelSet({phi::DataType::FLOAT32})}, {"sqrt", XPUKernelSet({phi::DataType::FLOAT32})},
{"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"square_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"square_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"square", {"square",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"squeeze2_grad", {"squeeze2_grad",
......
...@@ -640,6 +640,14 @@ PD_REGISTER_KERNEL(tanh_grad, ...@@ -640,6 +640,14 @@ PD_REGISTER_KERNEL(tanh_grad,
phi::TanhGradKernel, phi::TanhGradKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16) {}
PD_REGISTER_KERNEL(square_grad,
XPU,
ALL_LAYOUT,
phi::SquareGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
...@@ -652,5 +660,4 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel) ...@@ -652,5 +660,4 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(square_grad, SquareGradKernel)
PD_REGISTER_KERNEL(pow_grad, XPU, ALL_LAYOUT, phi::PowGradKernel, float) {} PD_REGISTER_KERNEL(pow_grad, XPU, ALL_LAYOUT, phi::PowGradKernel, float) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册