diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 29af9df55b8facaf20906bb560b0347020370fb9..224936a4d260d66f55f2a5eb292c70ea319f739c 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -7,8 +7,6 @@ param : [x] kernel : func : abs_double_grad - data_transform: - skip_transform : grad_x_grad - backward_op : abs_grad forward : abs (Tensor x) -> Tensor(out) diff --git a/paddle/phi/kernels/cpu/abs_grad_kernel.cc b/paddle/phi/kernels/cpu/abs_grad_kernel.cc index f32a9a075ce157d416b972dcb42814220ba3af8f..db6fff065c0578fa60b673be327602435a8c682a 100644 --- a/paddle/phi/kernels/cpu/abs_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/abs_grad_kernel.cc @@ -41,4 +41,6 @@ PD_REGISTER_KERNEL(abs_double_grad, int, int64_t, complex, - complex) {} + complex) { + kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} diff --git a/paddle/phi/kernels/gpu/abs_grad_kernel.cu b/paddle/phi/kernels/gpu/abs_grad_kernel.cu index 810aaf6e2afb9b412b097e6b9c722f706c77ebf9..8edb6b71224d6d4a1b601bc632675efe890dbf86 100644 --- a/paddle/phi/kernels/gpu/abs_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/abs_grad_kernel.cu @@ -45,4 +45,6 @@ PD_REGISTER_KERNEL(abs_double_grad, int64_t, phi::dtype::float16, complex, - complex) {} + complex) { + kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +}