From 37dbbbd19adc31563e3d6b4481048d715b99ea99 Mon Sep 17 00:00:00 2001 From: Ryan <44900829+DrRyanHuang@users.noreply.github.com> Date: Wed, 8 Mar 2023 14:24:18 +0800 Subject: [PATCH] w/o pre-commit (#51315) --- .../new_executor/interpreter/interpreter_util.cc | 2 -- paddle/phi/kernels/cpu/arg_min_max_kernel.cc | 8 ++++++-- paddle/phi/kernels/gpu/arg_min_max_kernel.cu | 8 ++++++-- paddle/phi/kernels/xpu/arg_min_max_kernel.cc | 4 +++- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index fdf5b1235f0..c11e48ec94c 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -58,8 +58,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "angle", "any_raw", "arg_sort", - "argmax", - "argmin", "as_real", "atan2", "auc", diff --git a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc index 694698050a0..77cf4cc0b03 100644 --- a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc +++ b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc @@ -195,7 +195,9 @@ PD_REGISTER_KERNEL(argmin, int32_t, int64_t, int16_t, - uint8_t) {} + uint8_t) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} PD_REGISTER_KERNEL(argmax, CPU, @@ -206,4 +208,6 @@ PD_REGISTER_KERNEL(argmax, int32_t, int64_t, int16_t, - uint8_t) {} + uint8_t) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu index 199ecc8e5b9..c42ad005c30 100644 --- a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu +++ b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu @@ -266,7 +266,9 @@ PD_REGISTER_KERNEL(argmin, int32_t, int64_t, int16_t, - uint8_t) {} + uint8_t) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} PD_REGISTER_KERNEL(argmax, GPU, @@ -279,4 +281,6 @@ PD_REGISTER_KERNEL(argmax, int32_t, int64_t, int16_t, - uint8_t) {} + uint8_t) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc index ebf13142345..555d7bf16e8 100644 --- a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc +++ b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc @@ -74,4 +74,6 @@ void ArgMaxKernel(const Context& dev_ctx, XPUAPIErrorMsg[r])); } } // namespace phi -PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {} +PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} -- GitLab