diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index fdf5b1235f0fbe94380282309e7e5083dabad73b..c11e48ec94c36aa1b04ba4c5cb3fa791be12f3d5 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -58,8 +58,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "angle", "any_raw", "arg_sort", - "argmax", - "argmin", "as_real", "atan2", "auc", diff --git a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc index 694698050a0c06cf1fbd4452ead9c25e542ecaa5..77cf4cc0b03e0902c407dbbfaa8aea743496c5b5 100644 --- a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc +++ b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc @@ -195,7 +195,9 @@ PD_REGISTER_KERNEL(argmin, int32_t, int64_t, int16_t, - uint8_t) {} + uint8_t) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} PD_REGISTER_KERNEL(argmax, CPU, @@ -206,4 +208,6 @@ PD_REGISTER_KERNEL(argmax, int32_t, int64_t, int16_t, - uint8_t) {} + uint8_t) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu index 199ecc8e5b9890b4cb6a7096a1452bd4f45dad5a..c42ad005c306c180d6fb9d02a0c39aa22c577df0 100644 --- a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu +++ b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu @@ -266,7 +266,9 @@ PD_REGISTER_KERNEL(argmin, int32_t, int64_t, int16_t, - uint8_t) {} + uint8_t) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} PD_REGISTER_KERNEL(argmax, GPU, @@ -279,4 +281,6 @@ PD_REGISTER_KERNEL(argmax, int32_t, int64_t, int16_t, - uint8_t) {} + uint8_t) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc index ebf13142345cee3fce41f1208acfcb09b0827e18..555d7bf16e85649043c78e7dda9da63869a37686 100644 --- a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc +++ b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc @@ -74,4 +74,6 @@ void ArgMaxKernel(const Context& dev_ctx, XPUAPIErrorMsg[r])); } } // namespace phi -PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {} +PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +}