From da0c7e1454b787b65333750b55de89cfa8dd565e Mon Sep 17 00:00:00 2001 From: LinearTemporalLogic <127285600+LinearTemporalLogic@users.noreply.github.com> Date: Tue, 11 Apr 2023 11:04:45 +0800 Subject: [PATCH] Add output defs for eigh kernel (#51362) * Add output defs for eigh kernel * fix * update * update * fix * fix --- paddle/phi/infermeta/unary.cc | 2 ++ paddle/phi/kernels/cpu/eigh_kernel.cc | 6 +++++- paddle/phi/kernels/gpu/eigh_kernel.cu | 6 +++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 0aac6f969be..8c87abf4fd8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -844,7 +844,9 @@ void EighInferMeta(const MetaTensor& x, values_dim.emplace_back(input_dim[i]); } out_w->set_dims(phi::make_ddim(values_dim)); + out_w->set_dtype(dtype::ToReal(x.dtype())); out_v->set_dims(input_dim); + out_v->set_dtype(dtype::ToReal(x.dtype())); } void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) { diff --git a/paddle/phi/kernels/cpu/eigh_kernel.cc b/paddle/phi/kernels/cpu/eigh_kernel.cc index 0f0a10c8377..cfb17589d50 100644 --- a/paddle/phi/kernels/cpu/eigh_kernel.cc +++ b/paddle/phi/kernels/cpu/eigh_kernel.cc @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/eigh_kernel.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/values_vectors_functor.h" @@ -40,4 +41,7 @@ PD_REGISTER_KERNEL(eigh, float, double, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/gpu/eigh_kernel.cu b/paddle/phi/kernels/gpu/eigh_kernel.cu index 3ffbb0b95b6..b5548da2c71 100644 --- a/paddle/phi/kernels/gpu/eigh_kernel.cu +++ b/paddle/phi/kernels/gpu/eigh_kernel.cu @@ -17,6 +17,7 @@ #include "paddle/phi/kernels/eigh_kernel.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/values_vectors_functor.h" @@ -43,6 +44,9 @@ PD_REGISTER_KERNEL(eigh, // cuda_only float, double, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); +} #endif // not PADDLE_WITH_HIP -- GitLab