未验证 提交 da0c7e14 编写于 作者: L LinearTemporalLogic 提交者: GitHub

Add output defs for eigh kernel (#51362)

* Add output defs for eigh kernel

* fix

* update

* update

* fix

* fix
上级 9eda000c
...@@ -844,7 +844,9 @@ void EighInferMeta(const MetaTensor& x, ...@@ -844,7 +844,9 @@ void EighInferMeta(const MetaTensor& x,
values_dim.emplace_back(input_dim[i]); values_dim.emplace_back(input_dim[i]);
} }
out_w->set_dims(phi::make_ddim(values_dim)); out_w->set_dims(phi::make_ddim(values_dim));
out_w->set_dtype(dtype::ToReal(x.dtype()));
out_v->set_dims(input_dim); out_v->set_dims(input_dim);
out_v->set_dtype(dtype::ToReal(x.dtype()));
} }
void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) { void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/eigh_kernel.h" #include "paddle/phi/kernels/eigh_kernel.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h" #include "paddle/phi/kernels/funcs/values_vectors_functor.h"
...@@ -40,4 +41,7 @@ PD_REGISTER_KERNEL(eigh, ...@@ -40,4 +41,7 @@ PD_REGISTER_KERNEL(eigh,
float, float,
double, double,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/phi/kernels/eigh_kernel.h" #include "paddle/phi/kernels/eigh_kernel.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h" #include "paddle/phi/kernels/funcs/values_vectors_functor.h"
...@@ -43,6 +44,9 @@ PD_REGISTER_KERNEL(eigh, // cuda_only ...@@ -43,6 +44,9 @@ PD_REGISTER_KERNEL(eigh, // cuda_only
float, float,
double, double,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
#endif // not PADDLE_WITH_HIP #endif // not PADDLE_WITH_HIP
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册