未验证 提交 0b2a66bb 编写于 作者: A Aurelius84 提交者: GitHub

[Perf]Fix interploate OutSize data transform problem (#48498)

* [Perf]Fix interploate OutSize data transform problem

* fix code style

* fix grad

* fix phi kernel
上级 8a717a3e
......@@ -466,7 +466,9 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
}
}
#endif
if (var_name == "SizeTensor" || var_name == "Scale") {
if (var_name == "OutSize" || var_name == "SizeTensor" ||
var_name == "Scale") {
return expected_kernel_type;
}
return framework::OpKernelType(
......@@ -701,7 +703,8 @@ class InterpolateV2OpGrad : public framework::OperatorWithKernel {
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "SizeTensor" || var_name == "Scale") {
if (var_name == "OutSize" || var_name == "SizeTensor" ||
var_name == "Scale") {
return expected_kernel_type;
}
return framework::OpKernelType(
......
......@@ -1458,6 +1458,7 @@ PD_REGISTER_KERNEL(bilinear_interp,
double,
phi::dtype::float16,
int) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -1471,6 +1472,7 @@ PD_REGISTER_KERNEL(nearest_interp,
phi::dtype::bfloat16,
int,
int64_t) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -1482,6 +1484,7 @@ PD_REGISTER_KERNEL(trilinear_interp,
double,
phi::dtype::float16,
int) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -1493,6 +1496,7 @@ PD_REGISTER_KERNEL(linear_interp,
double,
phi::dtype::float16,
int) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -1504,6 +1508,7 @@ PD_REGISTER_KERNEL(bicubic_interp,
double,
phi::dtype::float16,
int) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册