未验证 提交 5cb95856 编写于 作者: I Infinity_lee 提交者: GitHub

add output defs for eig kernel (#51319)

* fix eig

* fix

* fix

* fix

* fix
上级 297182f7
......@@ -57,7 +57,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"arg_sort",
"atan2",
"clip_by_norm",
"eig",
"eig_grad",
"eigh",
"ftt_c2r",
......
......@@ -769,7 +769,7 @@ void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) {
}
void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) {
auto x_dims = x.dims();
phi::DDim x_dims = x.dims();
int rank = x_dims.size();
PADDLE_ENFORCE_GE(
rank,
......@@ -789,9 +789,13 @@ void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) {
for (int i = 0; i < rank - 1; ++i) {
batch_dims_vec.emplace_back(x_dims[i]);
}
const DataType& x_dtype = x.dtype();
const DataType& out_dtype =
IsComplexType(x_dtype) ? x_dtype : ToComplexType(x_dtype);
out_w->set_dims(phi::make_ddim(batch_dims_vec));
out_w->set_dtype(out_dtype);
out_v->set_dims(x_dims);
out_v->set_dtype(out_dtype);
}
void EighInferMeta(const MetaTensor& x,
......
......@@ -104,4 +104,7 @@ PD_REGISTER_KERNEL(eig,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册