diff --git a/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc b/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc index edd41b2c7a31d084011da7ec6e026135c5588cb3..dee6e9149ca2d4a894c448ff8fe76a38af648efe 100644 --- a/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc @@ -1041,28 +1041,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad, ALL_LAYOUT, phi::BilinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(nearest_interp_v2_grad, CPU, ALL_LAYOUT, phi::NearestInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(trilinear_interp_v2_grad, CPU, ALL_LAYOUT, phi::TrilinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(linear_interp_v2_grad, CPU, ALL_LAYOUT, phi::LinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(bicubic_interp_v2_grad, CPU, ALL_LAYOUT, phi::BicubicInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/cpu/interpolate_kernel.cc b/paddle/phi/kernels/cpu/interpolate_kernel.cc index 5259a770568e4d9ccf3755cfb42cac0f2d124d39..3649185a0c7ee18dd302c0b6638dd3eb752fadd8 100644 --- a/paddle/phi/kernels/cpu/interpolate_kernel.cc +++ b/paddle/phi/kernels/cpu/interpolate_kernel.cc @@ -1193,7 +1193,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2, phi::BilinearInterpKernel, float, double, - uint8_t) {} + uint8_t) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(nearest_interp_v2, CPU, ALL_LAYOUT, @@ -1202,24 +1205,36 @@ PD_REGISTER_KERNEL(nearest_interp_v2, double, int, int64_t, - uint8_t) {} + uint8_t) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(trilinear_interp_v2, CPU, ALL_LAYOUT, phi::TrilinearInterpKernel, float, double, - uint8_t) {} + uint8_t) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(linear_interp_v2, CPU, ALL_LAYOUT, phi::LinearInterpKernel, float, double, - uint8_t) {} + uint8_t) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(bicubic_interp_v2, CPU, ALL_LAYOUT, phi::BicubicInterpKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu index 175f09fccfa300fee21765d2fcb729aa251adf33..047b4ff69a784011f0cc7a87edd0a1bb2ea15d8f 100644 --- a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu @@ -1574,28 +1574,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad, ALL_LAYOUT, phi::BilinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(nearest_interp_v2_grad, GPU, ALL_LAYOUT, phi::NearestInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(trilinear_interp_v2_grad, GPU, ALL_LAYOUT, phi::TrilinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(linear_interp_v2_grad, GPU, ALL_LAYOUT, phi::LinearInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(bicubic_interp_v2_grad, GPU, ALL_LAYOUT, phi::BicubicInterpGradKernel, float, - double) {} + double) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/gpu/interpolate_kernel.cu b/paddle/phi/kernels/gpu/interpolate_kernel.cu index 7bc331c52a015d41118a87c46e79986c9d1b4426..c05514236e091617c0a5f2d01c34be15621d6885 100644 --- a/paddle/phi/kernels/gpu/interpolate_kernel.cu +++ b/paddle/phi/kernels/gpu/interpolate_kernel.cu @@ -1446,7 +1446,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2, phi::BilinearInterpKernel, float, double, - int) {} + int) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(nearest_interp_v2, GPU, ALL_LAYOUT, @@ -1454,25 +1457,37 @@ PD_REGISTER_KERNEL(nearest_interp_v2, float, double, int, - int64_t) {} + int64_t) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(trilinear_interp_v2, GPU, ALL_LAYOUT, phi::TrilinearInterpKernel, float, double, - int) {} + int) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(linear_interp_v2, GPU, ALL_LAYOUT, phi::LinearInterpKernel, float, double, - int) {} + int) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(bicubic_interp_v2, GPU, ALL_LAYOUT, phi::BicubicInterpKernel, float, double, - int) {} + int) { + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); +}