From 39899d79fa31022d885d21d5d432aa96fbb218fa Mon Sep 17 00:00:00 2001 From: Little-chick <74541422+Little-chick@users.noreply.github.com> Date: Mon, 13 Mar 2023 10:43:51 +0800 Subject: [PATCH] add register of auc (#51451) * Update interpreter_util.cc * Update auc_kernel.cc * Update auc_kernel.cu * Update auc_kernel.cc * Update auc_kernel.cu --- .../framework/new_executor/interpreter/interpreter_util.cc | 1 - paddle/phi/kernels/cpu/auc_kernel.cc | 6 +++++- paddle/phi/kernels/gpu/auc_kernel.cu | 6 +++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index fddeebc2153..161ad1a3f3e 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -56,7 +56,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "any_raw", "arg_sort", "atan2", - "auc", "clip_by_norm", "complex", "conv3d_coo", diff --git a/paddle/phi/kernels/cpu/auc_kernel.cc b/paddle/phi/kernels/cpu/auc_kernel.cc index 0cf85348e6a..647fc592d0e 100644 --- a/paddle/phi/kernels/cpu/auc_kernel.cc +++ b/paddle/phi/kernels/cpu/auc_kernel.cc @@ -207,4 +207,8 @@ void AucKernel(const Context &dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(auc, CPU, ALL_LAYOUT, phi::AucKernel, float) {} +PD_REGISTER_KERNEL(auc, CPU, ALL_LAYOUT, phi::AucKernel, float) { + kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT64); + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); + kernel->OutputAt(2).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/auc_kernel.cu b/paddle/phi/kernels/gpu/auc_kernel.cu index c815f33a667..f733df24cf8 100644 --- a/paddle/phi/kernels/gpu/auc_kernel.cu +++ b/paddle/phi/kernels/gpu/auc_kernel.cu @@ -273,4 +273,8 @@ void AucKernel(const Context &dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(auc, GPU, ALL_LAYOUT, phi::AucKernel, float) {} +PD_REGISTER_KERNEL(auc, GPU, ALL_LAYOUT, phi::AucKernel, float) { + kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT64); + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); + kernel->OutputAt(2).SetDataType(phi::DataType::INT64); +} -- GitLab