From 0b2a66bbf6610308c74d8d0dd78ce702b2febfbf Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 30 Nov 2022 14:01:24 +0800 Subject: [PATCH] [Perf]Fix interploate OutSize data transform problem (#48498) * [Perf]Fix interploate OutSize data transform problem * fix code style * fix grad * fix phi kernel --- paddle/fluid/operators/interpolate_v2_op.cc | 7 +++++-- paddle/phi/kernels/gpu/interpolate_kernel.cu | 5 +++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/interpolate_v2_op.cc b/paddle/fluid/operators/interpolate_v2_op.cc index 10a072b562..95404bbd4a 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cc +++ b/paddle/fluid/operators/interpolate_v2_op.cc @@ -466,7 +466,9 @@ class InterpolateV2Op : public framework::OperatorWithKernel { } } #endif - if (var_name == "SizeTensor" || var_name == "Scale") { + + if (var_name == "OutSize" || var_name == "SizeTensor" || + var_name == "Scale") { return expected_kernel_type; } return framework::OpKernelType( @@ -701,7 +703,8 @@ class InterpolateV2OpGrad : public framework::OperatorWithKernel { const std::string& var_name, const phi::DenseTensor& tensor, const framework::OpKernelType& expected_kernel_type) const override { - if (var_name == "SizeTensor" || var_name == "Scale") { + if (var_name == "OutSize" || var_name == "SizeTensor" || + var_name == "Scale") { return expected_kernel_type; } return framework::OpKernelType( diff --git a/paddle/phi/kernels/gpu/interpolate_kernel.cu b/paddle/phi/kernels/gpu/interpolate_kernel.cu index 07e113ef7a..8ca24b3e4f 100644 --- a/paddle/phi/kernels/gpu/interpolate_kernel.cu +++ b/paddle/phi/kernels/gpu/interpolate_kernel.cu @@ -1458,6 +1458,7 @@ PD_REGISTER_KERNEL(bilinear_interp, double, phi::dtype::float16, int) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1471,6 +1472,7 @@ PD_REGISTER_KERNEL(nearest_interp, phi::dtype::bfloat16, int, int64_t) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1482,6 +1484,7 @@ PD_REGISTER_KERNEL(trilinear_interp, double, phi::dtype::float16, int) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1493,6 +1496,7 @@ PD_REGISTER_KERNEL(linear_interp, double, phi::dtype::float16, int) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1504,6 +1508,7 @@ PD_REGISTER_KERNEL(bicubic_interp, double, phi::dtype::float16, int) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } -- GitLab