diff --git a/paddle/fluid/operators/metrics/accuracy_op_xpu.cc b/paddle/fluid/operators/metrics/accuracy_op_xpu.cc deleted file mode 100644 index 0ac30b3e8734718fb314acfea554bbe8a67f4fd6..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/metrics/accuracy_op_xpu.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device/device_wrapper.h" - -namespace paddle { -namespace operators { - -template -class AccuracyXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* inference = ctx.Input("Out"); - auto* indices = ctx.Input("Indices"); - auto* label = ctx.Input("Label"); - auto* accuracy = ctx.Output("Accuracy"); - auto* correct = ctx.Output("Correct"); - auto* total = ctx.Output("Total"); - int* correct_data = correct->mutable_data(ctx.GetPlace()); - int* total_data = total->mutable_data(ctx.GetPlace()); - float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); - const int64_t* indices_data = indices->data(); - const int64_t* label_data = label->data(); - size_t num_samples = inference->dims()[0]; - size_t class_dim = inference->dims()[1]; - if (num_samples == 0) { - return; - } - auto& dev_ctx = ctx.template device_context(); - xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - int size = num_samples * class_dim; - int* indices_int32_ptr = RAII_GUARD.alloc_l3_or_gm(size); - PADDLE_ENFORCE_XDNN_NOT_NULL(indices_int32_ptr); - int* label_int32_ptr = RAII_GUARD.alloc_l3_or_gm(size); - PADDLE_ENFORCE_XDNN_NOT_NULL(label_int32_ptr); - - int r = xpu::cast( - dev_ctx.x_context(), indices_data, indices_int32_ptr, size); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); - - r = xpu::cast( - dev_ctx.x_context(), label_data, label_int32_ptr, size); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); - - r = xpu::accuracy(dev_ctx.x_context(), - indices_int32_ptr, - label_int32_ptr, - num_samples, - class_dim, - correct_data, - total_data, - accuracy_data); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast_v2"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -PD_REGISTER_STRUCT_KERNEL( - accuracy, XPU, ALL_LAYOUT, ops::AccuracyXPUKernel, float) {} -#endif diff --git a/paddle/phi/kernels/xpu/accuracy_kernel.cc b/paddle/phi/kernels/xpu/accuracy_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..69de35615eda69b21798aec77a6c9bf704670b9b --- /dev/null +++ b/paddle/phi/kernels/xpu/accuracy_kernel.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/accuracy_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void AccuracyRawKernel(const Context& dev_ctx, + const DenseTensor& inference, + const DenseTensor& indices, + const DenseTensor& label, + DenseTensor* accuracy, + DenseTensor* correct, + DenseTensor* total) { + int* correct_data = dev_ctx.template Alloc(correct); + int* total_data = dev_ctx.template Alloc(total); + float* accuracy_data = dev_ctx.template Alloc(accuracy); + + const int64_t* indices_data = indices.data(); + const int64_t* label_data = label.data(); + + PADDLE_ENFORCE_EQ( + inference.dims().size(), + 2, + phi::errors::InvalidArgument( + "Rank(Input) of AccuracyOp must be 2, with shape " + "[sample_number, class_dim], But received rank(Input) is %d", + inference.dims().size())); + + int64_t num_samples = inference.dims()[0]; + int64_t class_dim = inference.dims()[1]; + + int r = xpu::accuracy(dev_ctx.x_context(), + indices_data, + label_data, + num_samples, + class_dim, + correct_data, + total_data, + accuracy_data); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "accuracy"); +} +} // namespace phi + +// TODO(add supported dtype.) +PD_REGISTER_KERNEL(accuracy, XPU, ALL_LAYOUT, phi::AccuracyRawKernel, float) { + kernel->InputAt(1).SetDataType(phi::DataType::INT64); + kernel->InputAt(2).SetDataType(phi::DataType::INT64); + kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(1).SetDataType(phi::DataType::INT32); + kernel->OutputAt(2).SetDataType(phi::DataType::INT32); +}