diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index a4fea714101cd877ec9dde16bdf500aa0548138f..0d00f91998d7d8201f7a76fc1827223c25891f14 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -19,8 +19,6 @@ param : [x] kernel : func : abs_grad - data_transform: - skip_transform : out_grad backward : abs_double_grad - backward_op : acos_grad diff --git a/paddle/phi/kernels/cpu/abs_grad_kernel.cc b/paddle/phi/kernels/cpu/abs_grad_kernel.cc index ca42a5eb2976f62708544e3d3bdd31f63d2a004f..f32a9a075ce157d416b972dcb42814220ba3af8f 100644 --- a/paddle/phi/kernels/cpu/abs_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/abs_grad_kernel.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/common/complex.h" +#include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/impl/abs_grad_kernel_impl.h" @@ -28,7 +29,9 @@ PD_REGISTER_KERNEL(abs_grad, int, int64_t, complex, - complex) {} + complex) { + kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} PD_REGISTER_KERNEL(abs_double_grad, CPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/abs_grad_kernel.cu b/paddle/phi/kernels/gpu/abs_grad_kernel.cu index 2d96a7a88e33ea45ad138896017f6fe860619808..810aaf6e2afb9b412b097e6b9c722f706c77ebf9 100644 --- a/paddle/phi/kernels/gpu/abs_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/abs_grad_kernel.cu @@ -16,6 +16,7 @@ #include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" +#include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/abs_grad_kernel_impl.h" @@ -31,7 +32,9 @@ PD_REGISTER_KERNEL(abs_grad, int64_t, phi::dtype::float16, complex, - complex) {} + complex) { + kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} PD_REGISTER_KERNEL(abs_double_grad, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/xpu/abs_grad_kernel.cc b/paddle/phi/kernels/xpu/abs_grad_kernel.cc index e49beee6847a525f54623e67ea1c48036dc83bfd..b9fab28254d29c87f2e10c2100c4b418a97b8ae2 100644 --- a/paddle/phi/kernels/xpu/abs_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/abs_grad_kernel.cc @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/abs_grad_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { @@ -35,4 +36,6 @@ void AbsGradKernel(const Context& ctx, } } // namespace phi -PD_REGISTER_KERNEL(abs_grad, XPU, ALL_LAYOUT, phi::AbsGradKernel, float) {} +PD_REGISTER_KERNEL(abs_grad, XPU, ALL_LAYOUT, phi::AbsGradKernel, float) { + kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +}