From fd9c555c0dbc2593f54e5238d10930f45defda27 Mon Sep 17 00:00:00 2001 From: wz1qqx <55830058+wz1qqx@users.noreply.github.com> Date: Thu, 8 Jun 2023 15:01:10 +0800 Subject: [PATCH] [XPU]add fp16 kernels (#54410) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 6 ++++-- paddle/phi/kernels/activation_kernel.cc | 3 ++- paddle/phi/kernels/xpu/activation_kernel.cc | 8 +++++++- paddle/phi/kernels/xpu/clip_kernel.cc | 17 +++++++++++++---- paddle/phi/kernels/xpu/conv_kernel.cc | 8 ++++++-- 5 files changed, 32 insertions(+), 10 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index add1d7eca7d..7a287c99339 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -140,6 +140,7 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"clip", XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, phi::DataType::INT64, phi::DataType::INT32})}, {"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -188,7 +189,8 @@ XPUOpMap& get_kl2_ops() { {"deformable_conv_v1_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"deformable_conv_v1", XPUKernelSet({phi::DataType::FLOAT32})}, {"depthwise_conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"depthwise_conv2d", XPUKernelSet({phi::DataType::FLOAT32})}, + {"depthwise_conv2d", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"depthwise_conv2d_transpose_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"depthwise_conv2d_transpose", @@ -599,7 +601,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT32, phi::DataType::INT8, phi::DataType::FLOAT32})}, - {"relu6", XPUKernelSet({phi::DataType::FLOAT32})}, + {"relu6", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"relu_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, diff --git a/paddle/phi/kernels/activation_kernel.cc 
b/paddle/phi/kernels/activation_kernel.cc index ab367a67269..068fd9b575a 100644 --- a/paddle/phi/kernels/activation_kernel.cc +++ b/paddle/phi/kernels/activation_kernel.cc @@ -62,7 +62,8 @@ PD_REGISTER_KERNEL(swish, #endif #if defined PADDLE_WITH_XPU -PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {} +PD_REGISTER_KERNEL( + relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float, phi::dtype::float16) {} PD_REGISTER_KERNEL( swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {} #endif diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index 1d1f33204bc..dd8d483a8b5 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -572,6 +572,13 @@ PD_REGISTER_KERNEL( PD_REGISTER_KERNEL( log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(relu6_raw, + XPU, + ALL_LAYOUT, + phi::Relu6RawKernel, + float, + phi::dtype::float16) {} + #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {} @@ -581,7 +588,6 @@ PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel) PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel) -PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel) PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel) PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel) diff --git a/paddle/phi/kernels/xpu/clip_kernel.cc b/paddle/phi/kernels/xpu/clip_kernel.cc index 820c85f7ea9..8b01c06c245 100644 --- a/paddle/phi/kernels/xpu/clip_kernel.cc +++ b/paddle/phi/kernels/xpu/clip_kernel.cc @@ -13,6 +13,9 @@ // limitations under the License. 
#include "paddle/phi/kernels/clip_kernel.h" + +#include "glog/logging.h" + #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/backends/xpu/xpu_header.h" #include "paddle/phi/core/kernel_registry.h" @@ -33,8 +36,8 @@ void ClipKernel(const Context& dev_ctx, x_data, out_data, x.numel(), - min.to<T>(), - max.to<T>()); + static_cast<XPUDataType>(min.to<T>()), + static_cast<XPUDataType>(max.to<T>())); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, @@ -46,5 +49,11 @@ void ClipKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - clip, XPU, ALL_LAYOUT, phi::ClipKernel, float, int64_t, int) {} +PD_REGISTER_KERNEL(clip, + XPU, + ALL_LAYOUT, + phi::ClipKernel, + float, + phi::dtype::float16, + int64_t, + int) {} diff --git a/paddle/phi/kernels/xpu/conv_kernel.cc b/paddle/phi/kernels/xpu/conv_kernel.cc index 7f242680414..e8148602d13 100644 --- a/paddle/phi/kernels/xpu/conv_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_kernel.cc @@ -310,7 +310,11 @@ void Conv3DKernel(const Context& dev_ctx, PD_REGISTER_KERNEL( conv2d, XPU, ALL_LAYOUT, phi::ConvKernel, float, phi::dtype::float16) {} -PD_REGISTER_KERNEL( - depthwise_conv2d, XPU, ALL_LAYOUT, phi::DepthwiseConvKernel, float) {} +PD_REGISTER_KERNEL(depthwise_conv2d, + XPU, + ALL_LAYOUT, + phi::DepthwiseConvKernel, + float, + phi::dtype::float16) {} PD_REGISTER_KERNEL( conv3d, XPU, ALL_LAYOUT, phi::Conv3DKernel, float, phi::dtype::float16) {} -- GitLab