From 5cb95856b60b2858e723f36a5b7e9fb136ecda3f Mon Sep 17 00:00:00 2001 From: Infinity_lee Date: Wed, 15 Mar 2023 20:57:00 +0800 Subject: [PATCH] add output defs for eig kernel (#51319) * fix eig * fix * fix * fix * fix --- .../new_executor/interpreter/interpreter_util.cc | 1 - paddle/phi/infermeta/unary.cc | 8 ++++++-- paddle/phi/kernels/cpu/eig_kernel.cc | 5 ++++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index b2ef8c2f7dc..47b2c2a2b47 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -57,7 +57,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "arg_sort", "atan2", "clip_by_norm", - "eig", "eig_grad", "eigh", "ftt_c2r", diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 68e7f24cb9d..42f8575450e 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -769,7 +769,7 @@ void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) { } void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) { - auto x_dims = x.dims(); + phi::DDim x_dims = x.dims(); int rank = x_dims.size(); PADDLE_ENFORCE_GE( rank, @@ -789,9 +789,13 @@ void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) { for (int i = 0; i < rank - 1; ++i) { batch_dims_vec.emplace_back(x_dims[i]); } - + const DataType& x_dtype = x.dtype(); + const DataType& out_dtype = + IsComplexType(x_dtype) ? x_dtype : ToComplexType(x_dtype); out_w->set_dims(phi::make_ddim(batch_dims_vec)); + out_w->set_dtype(out_dtype); out_v->set_dims(x_dims); + out_v->set_dtype(out_dtype); } void EighInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/cpu/eig_kernel.cc b/paddle/phi/kernels/cpu/eig_kernel.cc index c9bdf8af116..0ef1e199650 100644 --- a/paddle/phi/kernels/cpu/eig_kernel.cc +++ b/paddle/phi/kernels/cpu/eig_kernel.cc @@ -104,4 +104,7 @@ PD_REGISTER_KERNEL(eig, float, double, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); +} -- GitLab