未验证 提交 c6bf8812 编写于 作者: Z zyfncg 提交者: GitHub

fix data transform bug of interpolate op (#44401)

上级 b2224e6f
......@@ -1041,28 +1041,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad,
ALL_LAYOUT,
phi::BilinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::NearestInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::TrilinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::LinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::BicubicInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -1193,7 +1193,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2,
phi::BilinearInterpKernel,
float,
double,
uint8_t) {}
uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2,
CPU,
ALL_LAYOUT,
......@@ -1202,24 +1205,36 @@ PD_REGISTER_KERNEL(nearest_interp_v2,
double,
int,
int64_t,
uint8_t) {}
uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2,
CPU,
ALL_LAYOUT,
phi::TrilinearInterpKernel,
float,
double,
uint8_t) {}
uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2,
CPU,
ALL_LAYOUT,
phi::LinearInterpKernel,
float,
double,
uint8_t) {}
uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2,
CPU,
ALL_LAYOUT,
phi::BicubicInterpKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -1574,28 +1574,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad,
ALL_LAYOUT,
phi::BilinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::NearestInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::TrilinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::LinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::BicubicInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -1446,7 +1446,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2,
phi::BilinearInterpKernel,
float,
double,
int) {}
int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2,
GPU,
ALL_LAYOUT,
......@@ -1454,25 +1457,37 @@ PD_REGISTER_KERNEL(nearest_interp_v2,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2,
GPU,
ALL_LAYOUT,
phi::TrilinearInterpKernel,
float,
double,
int) {}
int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2,
GPU,
ALL_LAYOUT,
phi::LinearInterpKernel,
float,
double,
int) {}
int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2,
GPU,
ALL_LAYOUT,
phi::BicubicInterpKernel,
float,
double,
int) {}
int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册