diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 0aac6f969beb7232f9e4d4efae92d186dd9e8c53..8c87abf4fd86a51fffc65599878c0852bcfa60b8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -844,7 +844,9 @@ void EighInferMeta(const MetaTensor& x, values_dim.emplace_back(input_dim[i]); } out_w->set_dims(phi::make_ddim(values_dim)); + out_w->set_dtype(dtype::ToReal(x.dtype())); out_v->set_dims(input_dim); + out_v->set_dtype(dtype::ToReal(x.dtype())); } void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) { diff --git a/paddle/phi/kernels/cpu/eigh_kernel.cc b/paddle/phi/kernels/cpu/eigh_kernel.cc index 0f0a10c8377921b81e189f36bc6a926771876d26..cfb17589d505dd8d5f8c85402515f8870366bd76 100644 --- a/paddle/phi/kernels/cpu/eigh_kernel.cc +++ b/paddle/phi/kernels/cpu/eigh_kernel.cc @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/eigh_kernel.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/values_vectors_functor.h" @@ -40,4 +41,7 @@ PD_REGISTER_KERNEL(eigh, float, double, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/gpu/eigh_kernel.cu b/paddle/phi/kernels/gpu/eigh_kernel.cu index 3ffbb0b95b666523122265276c731c3252ccacc4..b5548da2c71416cf3320f4a4422a6c246891534e 100644 --- a/paddle/phi/kernels/gpu/eigh_kernel.cu +++ b/paddle/phi/kernels/gpu/eigh_kernel.cu @@ -17,6 +17,7 @@ #include "paddle/phi/kernels/eigh_kernel.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/values_vectors_functor.h" @@ -43,6 +44,9 @@ PD_REGISTER_KERNEL(eigh, // cuda_only float, double, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); +} #endif // not PADDLE_WITH_HIP