未验证 提交 da0c7e14 编写于 作者: L LinearTemporalLogic 提交者: GitHub

Add output defs for eigh kernel (#51362)

* Add output defs for eigh kernel

* fix

* update

* update

* fix

* fix
上级 9eda000c
......@@ -844,7 +844,9 @@ void EighInferMeta(const MetaTensor& x,
values_dim.emplace_back(input_dim[i]);
}
out_w->set_dims(phi::make_ddim(values_dim));
out_w->set_dtype(dtype::ToReal(x.dtype()));
out_v->set_dims(input_dim);
out_v->set_dtype(dtype::ToReal(x.dtype()));
}
void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {
......
......@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/eigh_kernel.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
......@@ -40,4 +41,7 @@ PD_REGISTER_KERNEL(eigh,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -17,6 +17,7 @@
#include "paddle/phi/kernels/eigh_kernel.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
......@@ -43,6 +44,9 @@ PD_REGISTER_KERNEL(eigh, // cuda_only
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
#endif // not PADDLE_WITH_HIP
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册