From ff803bdc94aba0c933a00df71feadd376d267e61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Thu, 9 Mar 2023 10:41:32 +0800 Subject: [PATCH] add REGISTER of as_real (#51313) --- .../framework/new_executor/interpreter/interpreter_util.cc | 1 - paddle/phi/infermeta/unary.cc | 1 + paddle/phi/kernels/cpu/as_real_kernel.cc | 4 +++- paddle/phi/kernels/gpu/as_real_kernel.cu | 4 +++- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 1f6cd392ec0..1baf27928bd 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -58,7 +58,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "angle", "any_raw", "arg_sort", - "as_real", "atan2", "auc", "bincount", diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 0f001ff023e..b2e9e653539 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -273,6 +273,7 @@ void AsRealInferMeta(const MetaTensor& input, MetaTensor* output) { auto out_dims = phi::make_ddim(out_dims_v); output->set_dims(out_dims); output->share_lod(input); + output->set_dtype(dtype::ToReal(input.dtype())); } void AsComplexInferMeta(const MetaTensor& input, MetaTensor* output) { diff --git a/paddle/phi/kernels/cpu/as_real_kernel.cc b/paddle/phi/kernels/cpu/as_real_kernel.cc index eb7584a28d6..5541a887c9f 100644 --- a/paddle/phi/kernels/cpu/as_real_kernel.cc +++ b/paddle/phi/kernels/cpu/as_real_kernel.cc @@ -23,4 +23,6 @@ using complex64 = ::phi::dtype::complex; using complex128 = ::phi::dtype::complex; PD_REGISTER_KERNEL( - as_real, CPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) {} + as_real, CPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/gpu/as_real_kernel.cu b/paddle/phi/kernels/gpu/as_real_kernel.cu index c6cb3aca226..83976381844 100644 --- a/paddle/phi/kernels/gpu/as_real_kernel.cu +++ b/paddle/phi/kernels/gpu/as_real_kernel.cu @@ -23,4 +23,6 @@ using complex64 = ::phi::dtype::complex; using complex128 = ::phi::dtype::complex; PD_REGISTER_KERNEL( - as_real, GPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) {} + as_real, GPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); +} -- GitLab