diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index b2ef8c2f7dca37b910dc7e88be590b2fc15feec5..47b2c2a2b47228f46bb3c0ff0e9cc5ae9d048fbb 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -57,7 +57,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "arg_sort", "atan2", "clip_by_norm", - "eig", "eig_grad", "eigh", "ftt_c2r", diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 68e7f24cb9dc19b9e6aa012a57227143b773e057..42f8575450ef2cb1682333bc8899a117b309a9c3 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -769,7 +769,7 @@ void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) { } void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) { - auto x_dims = x.dims(); + phi::DDim x_dims = x.dims(); int rank = x_dims.size(); PADDLE_ENFORCE_GE( rank, @@ -789,9 +789,13 @@ void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) { for (int i = 0; i < rank - 1; ++i) { batch_dims_vec.emplace_back(x_dims[i]); } - + const DataType& x_dtype = x.dtype(); + const DataType& out_dtype = + IsComplexType(x_dtype) ? x_dtype : ToComplexType(x_dtype); out_w->set_dims(phi::make_ddim(batch_dims_vec)); + out_w->set_dtype(out_dtype); out_v->set_dims(x_dims); + out_v->set_dtype(out_dtype); } void EighInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/cpu/eig_kernel.cc b/paddle/phi/kernels/cpu/eig_kernel.cc index c9bdf8af1168270b41ea4c8f642252c68acf9812..0ef1e19965093d93f992bc1f2afa99d4a9a9eae5 100644 --- a/paddle/phi/kernels/cpu/eig_kernel.cc +++ b/paddle/phi/kernels/cpu/eig_kernel.cc @@ -104,4 +104,7 @@ PD_REGISTER_KERNEL(eig, float, double, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); +}