From 3996f0de6514890181dbf78f95c2bbb39108d04f Mon Sep 17 00:00:00 2001 From: csy0225 <78470701+csy0225@users.noreply.github.com> Date: Fri, 31 Mar 2023 15:40:36 +0800 Subject: [PATCH] [XPU] interpolate support fp16 (#52358) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 6 ++- paddle/phi/kernels/xpu/interpolate_kernel.cc | 42 ++++++++++++-------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 023cfa33bef..438034e3645 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -75,7 +75,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::FLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, - {"bilinear_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})}, + {"bilinear_interp_v2", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"bitwise_not", XPUKernelSet({phi::DataType::BOOL})}, {"broadcast", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -496,7 +497,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT64})}, {"multi_encoder_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"nearest_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})}, + {"nearest_interp_v2", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"not_equal", XPUKernelSet({phi::DataType::INT64, diff --git a/paddle/phi/kernels/xpu/interpolate_kernel.cc b/paddle/phi/kernels/xpu/interpolate_kernel.cc index 091a8164ea4..43d73293015 100644 --- a/paddle/phi/kernels/xpu/interpolate_kernel.cc +++ b/paddle/phi/kernels/xpu/interpolate_kernel.cc @@ -38,6 +38,7 @@ void InterpolateKernel( bool align_corners, int align_mode, DenseTensor* output) { + using XPUType = typename XPUTypeTrait::Type; const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); int n, c, in_d, in_h, in_w; phi::funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); @@ -140,18 +141,19 @@ void InterpolateKernel( errors::InvalidArgument("XPU nearest is only support NCHW")); } - int r = xpu::interpolate2d(ctx.x_context(), - x.data(), - output->data(), - n, - c, - in_h, - in_w, - out_h, - out_w, - nearest, - trans_mode, - (data_layout == DataLayout::kNCHW)); + int r = + xpu::interpolate2d(ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(output->data()), + n, + c, + in_h, + in_w, + out_h, + out_w, + nearest, + trans_mode, + (data_layout == DataLayout::kNCHW)); PADDLE_ENFORCE_XDNN_SUCCESS(r, "interpolate2d"); } @@ -221,14 +223,22 @@ void NearestInterpKernel( } // namespace phi -PD_REGISTER_KERNEL( - bilinear_interp, XPU, ALL_LAYOUT, phi::BilinearInterpKernel, float) { +PD_REGISTER_KERNEL(bilinear_interp, + XPU, + ALL_LAYOUT, + phi::BilinearInterpKernel, + phi::dtype::float16, + float) { kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } -PD_REGISTER_KERNEL( - nearest_interp, XPU, ALL_LAYOUT, phi::NearestInterpKernel, float) { +PD_REGISTER_KERNEL(nearest_interp, + XPU, + ALL_LAYOUT, + phi::NearestInterpKernel, + phi::dtype::float16, + float) { kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); -- GitLab