From aba3c806b697d756a611585527e1705e721924d4 Mon Sep 17 00:00:00 2001 From: Leo Guo <58431564+ZibinGuo@users.noreply.github.com> Date: Tue, 8 Nov 2022 14:59:17 +0800 Subject: [PATCH] Fix bug of abs_double_grad in eager mode for kunlun, test=kunlun (#47722) --- paddle/phi/api/yaml/legacy_backward.yaml | 2 -- paddle/phi/kernels/cpu/abs_grad_kernel.cc | 4 +++- paddle/phi/kernels/gpu/abs_grad_kernel.cu | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 29af9df55b..224936a4d2 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -7,8 +7,6 @@ param : [x] kernel : func : abs_double_grad - data_transform: - skip_transform : grad_x_grad - backward_op : abs_grad forward : abs (Tensor x) -> Tensor(out) diff --git a/paddle/phi/kernels/cpu/abs_grad_kernel.cc b/paddle/phi/kernels/cpu/abs_grad_kernel.cc index f32a9a075c..db6fff065c 100644 --- a/paddle/phi/kernels/cpu/abs_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/abs_grad_kernel.cc @@ -41,4 +41,6 @@ PD_REGISTER_KERNEL(abs_double_grad, int, int64_t, complex, - complex) {} + complex) { + kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} diff --git a/paddle/phi/kernels/gpu/abs_grad_kernel.cu b/paddle/phi/kernels/gpu/abs_grad_kernel.cu index 810aaf6e2a..8edb6b7122 100644 --- a/paddle/phi/kernels/gpu/abs_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/abs_grad_kernel.cu @@ -45,4 +45,6 @@ PD_REGISTER_KERNEL(abs_double_grad, int64_t, phi::dtype::float16, complex, - complex) {} + complex) { + kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} -- GitLab