未验证 提交 5cb95856 编写于 作者: I Infinity_lee 提交者: GitHub

add output defs for eig kernel (#51319)

* fix eig

* fix

* fix

* fix

* fix
上级 297182f7
...@@ -57,7 +57,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -57,7 +57,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"arg_sort", "arg_sort",
"atan2", "atan2",
"clip_by_norm", "clip_by_norm",
"eig",
"eig_grad", "eig_grad",
"eigh", "eigh",
"ftt_c2r", "ftt_c2r",
......
...@@ -769,7 +769,7 @@ void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) { ...@@ -769,7 +769,7 @@ void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) {
} }
void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) { void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) {
auto x_dims = x.dims(); phi::DDim x_dims = x.dims();
int rank = x_dims.size(); int rank = x_dims.size();
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
rank, rank,
...@@ -789,9 +789,13 @@ void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) { ...@@ -789,9 +789,13 @@ void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) {
for (int i = 0; i < rank - 1; ++i) { for (int i = 0; i < rank - 1; ++i) {
batch_dims_vec.emplace_back(x_dims[i]); batch_dims_vec.emplace_back(x_dims[i]);
} }
const DataType& x_dtype = x.dtype();
const DataType& out_dtype =
IsComplexType(x_dtype) ? x_dtype : ToComplexType(x_dtype);
out_w->set_dims(phi::make_ddim(batch_dims_vec)); out_w->set_dims(phi::make_ddim(batch_dims_vec));
out_w->set_dtype(out_dtype);
out_v->set_dims(x_dims); out_v->set_dims(x_dims);
out_v->set_dtype(out_dtype);
} }
void EighInferMeta(const MetaTensor& x, void EighInferMeta(const MetaTensor& x,
......
...@@ -104,4 +104,7 @@ PD_REGISTER_KERNEL(eig, ...@@ -104,4 +104,7 @@ PD_REGISTER_KERNEL(eig,
float, float,
double, double,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册