From 383a3f8c7db70eac1bd8806dbe990b80a4d3fd75 Mon Sep 17 00:00:00 2001 From: Zhenghai Zhang <65210872+ccsuzzh@users.noreply.github.com> Date: Mon, 13 Mar 2023 14:30:45 +0800 Subject: [PATCH] Add output defs for mode kernel (#51363) * Add output defs for mode kernel * fix bug --- .../framework/new_executor/interpreter/interpreter_util.cc | 1 - paddle/phi/kernels/cpu/mode_kernel.cc | 4 +++- paddle/phi/kernels/gpu/mode_kernel.cu | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index e15fd1bc1a3..0c228abf9a9 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -74,7 +74,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "less_equal", "less_than", "merged_adam", - "mode", "momentum", "multiclass_nms3", "multinomial", diff --git a/paddle/phi/kernels/cpu/mode_kernel.cc b/paddle/phi/kernels/cpu/mode_kernel.cc index 762c146e735..352f5d0b69e 100644 --- a/paddle/phi/kernels/cpu/mode_kernel.cc +++ b/paddle/phi/kernels/cpu/mode_kernel.cc @@ -132,4 +132,6 @@ void ModeKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - mode, CPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {} + mode, CPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/mode_kernel.cu b/paddle/phi/kernels/gpu/mode_kernel.cu index 815ecb9e1a7..c834d87aca9 100644 --- a/paddle/phi/kernels/gpu/mode_kernel.cu +++ b/paddle/phi/kernels/gpu/mode_kernel.cu @@ -130,4 +130,6 @@ void ModeKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - mode, GPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {} + mode, GPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); +} -- GitLab