未验证 提交 81d0a3cc 编写于 作者: H haosicheng 提交者: GitHub

add square fp16 *test=kunlun (#48095)

上级 8d00f76e
......@@ -568,7 +568,9 @@ XPUOpMap& get_kl2_ops() {
{"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"square_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"square",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"squeeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
......
......@@ -479,6 +479,9 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL(
tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
square, XPU, ALL_LAYOUT, phi::SquareKernel, float, phi::dtype::float16) {}
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad
PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
......@@ -492,4 +495,3 @@ PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(square, SquareKernel)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册