未验证 提交 c6bf8812 编写于 作者: Z zyfncg 提交者: GitHub

fix data transform bug of interpolate op (#44401)

上级 b2224e6f
...@@ -1041,28 +1041,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad, ...@@ -1041,28 +1041,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::BilinearInterpGradKernel, phi::BilinearInterpGradKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2_grad, PD_REGISTER_KERNEL(nearest_interp_v2_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::NearestInterpGradKernel, phi::NearestInterpGradKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2_grad, PD_REGISTER_KERNEL(trilinear_interp_v2_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::TrilinearInterpGradKernel, phi::TrilinearInterpGradKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2_grad, PD_REGISTER_KERNEL(linear_interp_v2_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::LinearInterpGradKernel, phi::LinearInterpGradKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2_grad, PD_REGISTER_KERNEL(bicubic_interp_v2_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::BicubicInterpGradKernel, phi::BicubicInterpGradKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
...@@ -1193,7 +1193,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2, ...@@ -1193,7 +1193,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2,
phi::BilinearInterpKernel, phi::BilinearInterpKernel,
float, float,
double, double,
uint8_t) {} uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2, PD_REGISTER_KERNEL(nearest_interp_v2,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -1202,24 +1205,36 @@ PD_REGISTER_KERNEL(nearest_interp_v2, ...@@ -1202,24 +1205,36 @@ PD_REGISTER_KERNEL(nearest_interp_v2,
double, double,
int, int,
int64_t, int64_t,
uint8_t) {} uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2, PD_REGISTER_KERNEL(trilinear_interp_v2,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::TrilinearInterpKernel, phi::TrilinearInterpKernel,
float, float,
double, double,
uint8_t) {} uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2, PD_REGISTER_KERNEL(linear_interp_v2,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::LinearInterpKernel, phi::LinearInterpKernel,
float, float,
double, double,
uint8_t) {} uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2, PD_REGISTER_KERNEL(bicubic_interp_v2,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::BicubicInterpKernel, phi::BicubicInterpKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
...@@ -1574,28 +1574,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad, ...@@ -1574,28 +1574,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::BilinearInterpGradKernel, phi::BilinearInterpGradKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2_grad, PD_REGISTER_KERNEL(nearest_interp_v2_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::NearestInterpGradKernel, phi::NearestInterpGradKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2_grad, PD_REGISTER_KERNEL(trilinear_interp_v2_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::TrilinearInterpGradKernel, phi::TrilinearInterpGradKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2_grad, PD_REGISTER_KERNEL(linear_interp_v2_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::LinearInterpGradKernel, phi::LinearInterpGradKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2_grad, PD_REGISTER_KERNEL(bicubic_interp_v2_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::BicubicInterpGradKernel, phi::BicubicInterpGradKernel,
float, float,
double) {} double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
...@@ -1446,7 +1446,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2, ...@@ -1446,7 +1446,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2,
phi::BilinearInterpKernel, phi::BilinearInterpKernel,
float, float,
double, double,
int) {} int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2, PD_REGISTER_KERNEL(nearest_interp_v2,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -1454,25 +1457,37 @@ PD_REGISTER_KERNEL(nearest_interp_v2, ...@@ -1454,25 +1457,37 @@ PD_REGISTER_KERNEL(nearest_interp_v2,
float, float,
double, double,
int, int,
int64_t) {} int64_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2, PD_REGISTER_KERNEL(trilinear_interp_v2,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::TrilinearInterpKernel, phi::TrilinearInterpKernel,
float, float,
double, double,
int) {} int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2, PD_REGISTER_KERNEL(linear_interp_v2,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::LinearInterpKernel, phi::LinearInterpKernel,
float, float,
double, double,
int) {} int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2, PD_REGISTER_KERNEL(bicubic_interp_v2,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::BicubicInterpKernel, phi::BicubicInterpKernel,
float, float,
double, double,
int) {} int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册