diff --git a/paddle/fluid/operators/interpolate_v2_op.cc b/paddle/fluid/operators/interpolate_v2_op.cc index 10a072b5623f9d533274521b5cfe52c0ae2adc44..95404bbd4a8a7a71cf67990d8f987d22d0dcdbce 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cc +++ b/paddle/fluid/operators/interpolate_v2_op.cc @@ -466,7 +466,9 @@ class InterpolateV2Op : public framework::OperatorWithKernel { } } #endif - if (var_name == "SizeTensor" || var_name == "Scale") { + + if (var_name == "OutSize" || var_name == "SizeTensor" || + var_name == "Scale") { return expected_kernel_type; } return framework::OpKernelType( @@ -701,7 +703,8 @@ class InterpolateV2OpGrad : public framework::OperatorWithKernel { const std::string& var_name, const phi::DenseTensor& tensor, const framework::OpKernelType& expected_kernel_type) const override { - if (var_name == "SizeTensor" || var_name == "Scale") { + if (var_name == "OutSize" || var_name == "SizeTensor" || + var_name == "Scale") { return expected_kernel_type; } return framework::OpKernelType( diff --git a/paddle/phi/kernels/gpu/interpolate_kernel.cu b/paddle/phi/kernels/gpu/interpolate_kernel.cu index 07e113ef7aa8004284b35ecf14345c4dc0491261..8ca24b3e4f05de66759e7c9c054b19119cf72168 100644 --- a/paddle/phi/kernels/gpu/interpolate_kernel.cu +++ b/paddle/phi/kernels/gpu/interpolate_kernel.cu @@ -1458,6 +1458,7 @@ PD_REGISTER_KERNEL(bilinear_interp, double, phi::dtype::float16, int) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1471,6 +1472,7 @@ PD_REGISTER_KERNEL(nearest_interp, phi::dtype::bfloat16, int, int64_t) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1482,6 +1484,7 @@ PD_REGISTER_KERNEL(trilinear_interp, double, phi::dtype::float16, int) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1493,6 +1496,7 @@ PD_REGISTER_KERNEL(linear_interp, double, phi::dtype::float16, int) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1504,6 +1508,7 @@ PD_REGISTER_KERNEL(bicubic_interp, double, phi::dtype::float16, int) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); }