diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 023cfa33befc9eb5dc8625ea8050984d1e8ad1d7..438034e3645fc08409f4dbd40d8bbcf985236db6 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -75,7 +75,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::FLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, - {"bilinear_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})}, + {"bilinear_interp_v2", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"bitwise_not", XPUKernelSet({phi::DataType::BOOL})}, {"broadcast", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -496,7 +497,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT64})}, {"multi_encoder_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"nearest_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})}, + {"nearest_interp_v2", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"not_equal", XPUKernelSet({phi::DataType::INT64, diff --git a/paddle/phi/kernels/xpu/interpolate_kernel.cc b/paddle/phi/kernels/xpu/interpolate_kernel.cc index 091a8164ea4ec94016126b8635571197ac08db02..43d732930152673e0af496d3180a12dc3d3a7e6c 100644 --- a/paddle/phi/kernels/xpu/interpolate_kernel.cc +++ b/paddle/phi/kernels/xpu/interpolate_kernel.cc @@ -38,6 +38,7 @@ void InterpolateKernel( bool align_corners, int align_mode, DenseTensor* output) { + using XPUType = typename XPUTypeTrait::Type; const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); int n, c, in_d, in_h, in_w; phi::funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); @@ -140,18 +141,19 @@ void InterpolateKernel( errors::InvalidArgument("XPU nearest is only support NCHW")); } - int r = xpu::interpolate2d(ctx.x_context(), - x.data(), - output->data(), - n, - c, - in_h, - in_w, - out_h, - out_w, - nearest, - trans_mode, - (data_layout == DataLayout::kNCHW)); + int r = + xpu::interpolate2d(ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(output->data()), + n, + c, + in_h, + in_w, + out_h, + out_w, + nearest, + trans_mode, + (data_layout == DataLayout::kNCHW)); PADDLE_ENFORCE_XDNN_SUCCESS(r, "interpolate2d"); } @@ -221,14 +223,22 @@ void NearestInterpKernel( } // namespace phi -PD_REGISTER_KERNEL( - bilinear_interp, XPU, ALL_LAYOUT, phi::BilinearInterpKernel, float) { +PD_REGISTER_KERNEL(bilinear_interp, + XPU, + ALL_LAYOUT, + phi::BilinearInterpKernel, + phi::dtype::float16, + float) { kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } -PD_REGISTER_KERNEL( - nearest_interp, XPU, ALL_LAYOUT, phi::NearestInterpKernel, float) { +PD_REGISTER_KERNEL(nearest_interp, + XPU, + ALL_LAYOUT, + phi::NearestInterpKernel, + phi::dtype::float16, + float) { kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);