From da551b2e05c8d5fbf6158d1487212386e9be8baf Mon Sep 17 00:00:00 2001 From: Sanbu <96160062+sanbuphy@users.noreply.github.com> Date: Sun, 12 Mar 2023 11:31:33 +0800 Subject: [PATCH] Add output defs for all_close all_raw kernel (#51410) * Add output defs for all_close all_raw kernel * Update interpreter_util.cc --- .../new_executor/interpreter/interpreter_util.cc | 2 -- paddle/phi/kernels/cpu/reduce_all_kernel.cc | 4 +++- paddle/phi/kernels/kps/reduce_all_kernel.cu | 8 ++++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 62aab7383a2..b4e6d8a7d12 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -53,8 +53,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "abs", "adam", "adamw", - "all_close", - "all_raw", "any_raw", "arg_sort", "atan2", diff --git a/paddle/phi/kernels/cpu/reduce_all_kernel.cc b/paddle/phi/kernels/cpu/reduce_all_kernel.cc index 60094d1345a..1dea17d4b7a 100644 --- a/paddle/phi/kernels/cpu/reduce_all_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_all_kernel.cc @@ -35,4 +35,6 @@ void AllRawKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(all_raw, CPU, ALL_LAYOUT, phi::AllRawKernel, bool) {} +PD_REGISTER_KERNEL(all_raw, CPU, ALL_LAYOUT, phi::AllRawKernel, bool) { + kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); +} diff --git a/paddle/phi/kernels/kps/reduce_all_kernel.cu b/paddle/phi/kernels/kps/reduce_all_kernel.cu index d4d4596917b..c0c338bb4f2 100644 --- a/paddle/phi/kernels/kps/reduce_all_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_all_kernel.cu @@ -34,7 +34,11 @@ void AllRawKernel(const Context& dev_ctx, } // namespace phi #ifdef PADDLE_WITH_XPU_KP -PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) {} +PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) { + kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); +} #else -PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) {} +PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) { + kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); +} #endif -- GitLab