未验证 提交 0b2a66bb 编写于 作者: A Aurelius84 提交者: GitHub

[Perf]Fix interploate OutSize data transform problem (#48498)

* [Perf]Fix interploate OutSize data transform problem

* fix code style

* fix grad

* fix phi kernel
上级 8a717a3e
...@@ -466,7 +466,9 @@ class InterpolateV2Op : public framework::OperatorWithKernel { ...@@ -466,7 +466,9 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
} }
} }
#endif #endif
if (var_name == "SizeTensor" || var_name == "Scale") {
if (var_name == "OutSize" || var_name == "SizeTensor" ||
var_name == "Scale") {
return expected_kernel_type; return expected_kernel_type;
} }
return framework::OpKernelType( return framework::OpKernelType(
...@@ -701,7 +703,8 @@ class InterpolateV2OpGrad : public framework::OperatorWithKernel { ...@@ -701,7 +703,8 @@ class InterpolateV2OpGrad : public framework::OperatorWithKernel {
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override { const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "SizeTensor" || var_name == "Scale") { if (var_name == "OutSize" || var_name == "SizeTensor" ||
var_name == "Scale") {
return expected_kernel_type; return expected_kernel_type;
} }
return framework::OpKernelType( return framework::OpKernelType(
......
...@@ -1458,6 +1458,7 @@ PD_REGISTER_KERNEL(bilinear_interp, ...@@ -1458,6 +1458,7 @@ PD_REGISTER_KERNEL(bilinear_interp,
double, double,
phi::dtype::float16, phi::dtype::float16,
int) { int) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
...@@ -1471,6 +1472,7 @@ PD_REGISTER_KERNEL(nearest_interp, ...@@ -1471,6 +1472,7 @@ PD_REGISTER_KERNEL(nearest_interp,
phi::dtype::bfloat16, phi::dtype::bfloat16,
int, int,
int64_t) { int64_t) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
...@@ -1482,6 +1484,7 @@ PD_REGISTER_KERNEL(trilinear_interp, ...@@ -1482,6 +1484,7 @@ PD_REGISTER_KERNEL(trilinear_interp,
double, double,
phi::dtype::float16, phi::dtype::float16,
int) { int) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
...@@ -1493,6 +1496,7 @@ PD_REGISTER_KERNEL(linear_interp, ...@@ -1493,6 +1496,7 @@ PD_REGISTER_KERNEL(linear_interp,
double, double,
phi::dtype::float16, phi::dtype::float16,
int) { int) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
...@@ -1504,6 +1508,7 @@ PD_REGISTER_KERNEL(bicubic_interp, ...@@ -1504,6 +1508,7 @@ PD_REGISTER_KERNEL(bicubic_interp,
double, double,
phi::dtype::float16, phi::dtype::float16,
int) { int) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册