From 0f79444e3d344ea00554f1c1e158cfd93f00c5a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Thu, 9 Mar 2023 11:48:11 +0800 Subject: [PATCH] [phi] add register of accuracy (#51308) * add REGISTER of float32 in accuracy * fix something --- .../framework/new_executor/interpreter/interpreter_util.cc | 1 - paddle/phi/kernels/cpu/accuracy_kernel.cc | 3 +++ paddle/phi/kernels/gpu/accuracy_kernel.cu | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 05eee2e7f96..6bea96b8796 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -51,7 +51,6 @@ using VariableIdMap = std::map>; // These Op needs set output dtype when register phi kernel, but they didn't static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "abs", - "accuracy", "adam", "adamw", "all_close", diff --git a/paddle/phi/kernels/cpu/accuracy_kernel.cc b/paddle/phi/kernels/cpu/accuracy_kernel.cc index 2c9312e63ac..4f39d28816a 100644 --- a/paddle/phi/kernels/cpu/accuracy_kernel.cc +++ b/paddle/phi/kernels/cpu/accuracy_kernel.cc @@ -96,4 +96,7 @@ PD_REGISTER_KERNEL( accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) { kernel->InputAt(1).SetDataType(phi::DataType::INT64); kernel->InputAt(2).SetDataType(phi::DataType::INT64); + kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); + kernel->OutputAt(2).SetDataType(phi::DataType::INT64); } diff --git a/paddle/phi/kernels/gpu/accuracy_kernel.cu b/paddle/phi/kernels/gpu/accuracy_kernel.cu index 6cdad23bfd5..f67605714ab 100644 --- a/paddle/phi/kernels/gpu/accuracy_kernel.cu +++ b/paddle/phi/kernels/gpu/accuracy_kernel.cu @@ -140,4 +140,6 @@ PD_REGISTER_KERNEL(accuracy, double) { kernel->InputAt(1).SetDataType(phi::DataType::INT64); kernel->InputAt(2).SetDataType(phi::DataType::INT64); + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); + kernel->OutputAt(2).SetDataType(phi::DataType::INT64); } -- GitLab