diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 05eee2e7f9649b9986e909ea3dff92a13dd6131d..6bea96b879651983bdc7c0df678ca51668cf14ae 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -51,7 +51,6 @@ using VariableIdMap = std::map>; // These Op needs set output dtype when register phi kernel, but they didn't static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "abs", - "accuracy", "adam", "adamw", "all_close", diff --git a/paddle/phi/kernels/cpu/accuracy_kernel.cc b/paddle/phi/kernels/cpu/accuracy_kernel.cc index 2c9312e63ac89994ed19d1dc77dd36bc20c3e7be..4f39d28816ae3d99a7431aa0c146b05db9c66ecc 100644 --- a/paddle/phi/kernels/cpu/accuracy_kernel.cc +++ b/paddle/phi/kernels/cpu/accuracy_kernel.cc @@ -96,4 +96,7 @@ PD_REGISTER_KERNEL( accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) { kernel->InputAt(1).SetDataType(phi::DataType::INT64); kernel->InputAt(2).SetDataType(phi::DataType::INT64); + kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); + kernel->OutputAt(2).SetDataType(phi::DataType::INT64); } diff --git a/paddle/phi/kernels/gpu/accuracy_kernel.cu b/paddle/phi/kernels/gpu/accuracy_kernel.cu index 6cdad23bfd5e180ecd943e1462de111c2bf318c9..f67605714aba867acbf71c8be06b6462ec1edb66 100644 --- a/paddle/phi/kernels/gpu/accuracy_kernel.cu +++ b/paddle/phi/kernels/gpu/accuracy_kernel.cu @@ -140,4 +140,6 @@ PD_REGISTER_KERNEL(accuracy, double) { kernel->InputAt(1).SetDataType(phi::DataType::INT64); kernel->InputAt(2).SetDataType(phi::DataType::INT64); + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); + kernel->OutputAt(2).SetDataType(phi::DataType::INT64); }