From c6bf88127e2e53f144165496e44204b7d8fd80e2 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 18 Jul 2022 20:05:26 +0800 Subject: [PATCH] fix data transform bug of interpolate op (#44401) --- .../kernels/cpu/interpolate_grad_kernel.cc | 25 +++++++++++++++---- paddle/phi/kernels/cpu/interpolate_kernel.cc | 25 +++++++++++++++---- .../kernels/gpu/interpolate_grad_kernel.cu | 25 +++++++++++++++---- paddle/phi/kernels/gpu/interpolate_kernel.cu | 25 +++++++++++++++---- 4 files changed, 80 insertions(+), 20 deletions(-) diff --git a/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc b/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc index edd41b2c7a..dee6e9149c 100644 --- a/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc @@ -1041,28 +1041,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad, ALL_LAYOUT, phi::BilinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(nearest_interp_v2_grad, CPU, ALL_LAYOUT, phi::NearestInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(trilinear_interp_v2_grad, CPU, ALL_LAYOUT, phi::TrilinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(linear_interp_v2_grad, CPU, ALL_LAYOUT, phi::LinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(bicubic_interp_v2_grad, CPU, ALL_LAYOUT, phi::BicubicInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/cpu/interpolate_kernel.cc b/paddle/phi/kernels/cpu/interpolate_kernel.cc index 5259a77056..3649185a0c 100644 --- a/paddle/phi/kernels/cpu/interpolate_kernel.cc +++ b/paddle/phi/kernels/cpu/interpolate_kernel.cc @@ -1193,7 +1193,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2, phi::BilinearInterpKernel, float, double, - uint8_t) {} + uint8_t) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(nearest_interp_v2, CPU, ALL_LAYOUT, @@ -1202,24 +1205,36 @@ PD_REGISTER_KERNEL(nearest_interp_v2, double, int, int64_t, - uint8_t) {} + uint8_t) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(trilinear_interp_v2, CPU, ALL_LAYOUT, phi::TrilinearInterpKernel, float, double, - uint8_t) {} + uint8_t) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(linear_interp_v2, CPU, ALL_LAYOUT, phi::LinearInterpKernel, float, double, - uint8_t) {} + uint8_t) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(bicubic_interp_v2, CPU, ALL_LAYOUT, phi::BicubicInterpKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu index 175f09fccf..047b4ff69a 100644 --- a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu @@ -1574,28 +1574,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad, ALL_LAYOUT, phi::BilinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(nearest_interp_v2_grad, GPU, ALL_LAYOUT, phi::NearestInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(trilinear_interp_v2_grad, GPU, ALL_LAYOUT, phi::TrilinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(linear_interp_v2_grad, GPU, ALL_LAYOUT, phi::LinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(bicubic_interp_v2_grad, GPU, ALL_LAYOUT, phi::BicubicInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/gpu/interpolate_kernel.cu b/paddle/phi/kernels/gpu/interpolate_kernel.cu index 7bc331c52a..c05514236e 100644 --- a/paddle/phi/kernels/gpu/interpolate_kernel.cu +++ b/paddle/phi/kernels/gpu/interpolate_kernel.cu @@ -1446,7 +1446,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2, phi::BilinearInterpKernel, float, double, - int) {} + int) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(nearest_interp_v2, GPU, ALL_LAYOUT, @@ -1454,25 +1457,37 @@ PD_REGISTER_KERNEL(nearest_interp_v2, float, double, int, - int64_t) {} + int64_t) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(trilinear_interp_v2, GPU, ALL_LAYOUT, phi::TrilinearInterpKernel, float, double, - int) {} + int) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(linear_interp_v2, GPU, ALL_LAYOUT, phi::LinearInterpKernel, float, double, - int) {} + int) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(bicubic_interp_v2, GPU, ALL_LAYOUT, phi::BicubicInterpKernel, float, double, - int) {} + int) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} -- GitLab