diff --git a/paddle/fluid/operators/metrics/accuracy_op.cc b/paddle/fluid/operators/metrics/accuracy_op.cc index 3692ace8bb5a46b06bd10a07a5d5d95d8825bdc6..056620db5b96691b59e4778208b4dafa5a68bd9a 100644 --- a/paddle/fluid/operators/metrics/accuracy_op.cc +++ b/paddle/fluid/operators/metrics/accuracy_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/metrics/accuracy_op.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -123,13 +123,10 @@ with the input Out(Inference). } // namespace operators } // namespace paddle +// FIXME(typhoonzero): types of T is for infernece data. +// label data is always int. namespace ops = paddle::operators; REGISTER_OPERATOR( accuracy, ops::AccuracyOp, ops::AccuracyOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -// FIXME(typhoonzero): types of T is for infernece data. -// label data is always int. -REGISTER_OP_CPU_KERNEL(accuracy, - ops::AccuracyKernel, - ops::AccuracyKernel); diff --git a/paddle/fluid/operators/metrics/accuracy_op.cu b/paddle/fluid/operators/metrics/accuracy_op.cu deleted file mode 100644 index 6f19100fa9d37e2efedad60a982bf19b09cac736..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/metrics/accuracy_op.cu +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include "paddle/fluid/operators/metrics/accuracy_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/fluid/platform/float16.h" - -namespace paddle { -namespace operators { -using platform::PADDLE_CUDA_NUM_THREADS; - -template -__global__ void AccuracyCudaKernel(const int N, const int D, - const int64_t* Xdata, - const int64_t* labeldata, int* correct_data, - float* accuracy, int* total_data) { - int count = 0; - __shared__ int total[BlockSize]; - - // support only 1 block - for (int i = threadIdx.x; i < (N); i += BlockSize) { - for (int j = 0; j < D; ++j) { - if (Xdata[i * D + j] == labeldata[i]) { - ++count; - break; - } - } - } - total[threadIdx.x] = count; - __syncthreads(); - -// reduce the count with init value 0, and output accuracy. -#ifdef PADDLE_WITH_CUDA - int result = thrust::reduce(thrust::device, total, total + BlockSize, 0); -#else - // HIP thrust::reduce not support __device__ - for (int s = BlockSize / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - total[threadIdx.x] += total[threadIdx.x + s]; - } - __syncthreads(); - } - int result = total[0]; -#endif - if (threadIdx.x == 0) { - *correct_data = result; - *accuracy = static_cast(result) / static_cast(N); - *total_data = N; - } -} - -template -class AccuracyOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* inference = ctx.Input("Out"); - auto* indices = ctx.Input("Indices"); - auto* label = ctx.Input("Label"); - - auto* accuracy = ctx.Output("Accuracy"); - auto* correct = ctx.Output("Correct"); - auto* total = ctx.Output("Total"); - // FIXME(typhoonzero): only support indices currently - // if add support for output values, how to detect the data type? - const int64_t* indices_data = indices->data(); - const int64_t* label_data = label->data(); - - int* correct_data = correct->mutable_data(ctx.GetPlace()); - int* total_data = total->mutable_data(ctx.GetPlace()); - float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); - - int num_samples = static_cast(inference->dims()[0]); - size_t infer_width = inference->dims()[1]; - auto stream = ctx.cuda_device_context().stream(); - platform::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream); - - if (num_samples == 0) { - return; - } - - AccuracyCudaKernel< - PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( - num_samples, infer_width, indices_data, label_data, correct_data, - accuracy_data, total_data); - } -}; - -} // namespace operators -} // namespace paddle - -// FIXME(typhoonzero): types of T is for inference data. -// label data is always int64 -REGISTER_OP_CUDA_KERNEL( - accuracy, paddle::operators::AccuracyOpCUDAKernel, - paddle::operators::AccuracyOpCUDAKernel, - paddle::operators::AccuracyOpCUDAKernel); diff --git a/paddle/fluid/operators/metrics/accuracy_op.h b/paddle/fluid/operators/metrics/accuracy_op.h deleted file mode 100644 index 94e5bf8257e67b9fd01aa9ae45a25d90963fef13..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/metrics/accuracy_op.h +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class AccuracyKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* inference = ctx.Input("Out"); - auto* indices = ctx.Input("Indices"); - auto* label = ctx.Input("Label"); - auto* accuracy = ctx.Output("Accuracy"); - auto* correct = ctx.Output("Correct"); - auto* total = ctx.Output("Total"); - - int* correct_data = correct->mutable_data(ctx.GetPlace()); - int* total_data = total->mutable_data(ctx.GetPlace()); - float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); - - const int64_t* indices_data = indices->data(); - const int64_t* label_data = label->data(); - - size_t num_samples = inference->dims()[0]; - size_t class_dim = inference->dims()[1]; - *accuracy_data = 0.0f; - - if (num_samples == 0) { - return; - } - - int num_correct = 0; - // assume inference is already the topk of the output - for (size_t i = 0; i < num_samples; ++i) { - PADDLE_ENFORCE_GE( - label_data[i], 0, - platform::errors::InvalidArgument( - "label of AccuracyOp must >= 0, But received label[%d] is %d", i, - label_data[i])); - for (size_t j = 0; j < class_dim; ++j) { - if (indices_data[i * class_dim + j] == label_data[i]) { - ++num_correct; - break; - } - } - } - - *correct_data = num_correct; - *total_data = num_samples; - *accuracy_data = - static_cast(num_correct) / static_cast(num_samples); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/metrics/accuracy_op_mlu.cc b/paddle/fluid/operators/metrics/accuracy_op_mlu.cc index 2598d3b0277c94a52e1fa14b04c00b595071f312..1ce02ff4525c9692f88ed42b79ff336cc0113c41 100644 --- a/paddle/fluid/operators/metrics/accuracy_op_mlu.cc +++ b/paddle/fluid/operators/metrics/accuracy_op_mlu.cc @@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/metrics/accuracy_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h" namespace paddle { diff --git a/paddle/fluid/operators/metrics/accuracy_op_npu.cc b/paddle/fluid/operators/metrics/accuracy_op_npu.cc index e83278f88b82a31eb445a0a86e3003e96acf395e..9f2ca4165f33a28902bfe20207b12bad2af49fad 100644 --- a/paddle/fluid/operators/metrics/accuracy_op_npu.cc +++ b/paddle/fluid/operators/metrics/accuracy_op_npu.cc @@ -13,7 +13,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/metrics/accuracy_op.h" +#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/metrics/accuracy_op_xpu.cc b/paddle/fluid/operators/metrics/accuracy_op_xpu.cc index de71312d78df99adc3b3663f2fcbb3943373982e..3cc1be4de8a82ff263824ab4852178f735596d45 100644 --- a/paddle/fluid/operators/metrics/accuracy_op_xpu.cc +++ b/paddle/fluid/operators/metrics/accuracy_op_xpu.cc @@ -14,12 +14,14 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/metrics/accuracy_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { +using Tensor = paddle::framework::Tensor; template class AccuracyXPUKernel : public framework::OpKernel { public: diff --git a/paddle/phi/kernels/accuracy_kernel.h b/paddle/phi/kernels/accuracy_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..8f2dbb96f86544882c3218a937225dd27978c15f --- /dev/null +++ b/paddle/phi/kernels/accuracy_kernel.h @@ -0,0 +1,30 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void AccuracyRawKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& indices, + const DenseTensor& label, + DenseTensor* accuracy, + DenseTensor* correct, + DenseTensor* total); +} // namespace phi diff --git a/paddle/phi/kernels/cpu/accuracy_kernel.cc b/paddle/phi/kernels/cpu/accuracy_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..c57ec69b73a230df48411f4074935e2bb4bce461 --- /dev/null +++ b/paddle/phi/kernels/cpu/accuracy_kernel.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/accuracy_kernel.h" + +#include +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void AccuracyRawKernel(const Context& dev_ctx, + const DenseTensor& inference, + const DenseTensor& indices, + const DenseTensor& label, + DenseTensor* accuracy, + DenseTensor* correct, + DenseTensor* total) { + int* correct_data = dev_ctx.template Alloc(correct); + int* total_data = dev_ctx.template Alloc(total); + float* accuracy_data = dev_ctx.template Alloc(accuracy); + + const int64_t* indices_data = indices.data(); + const int64_t* label_data = label.data(); + + size_t num_samples = inference.dims()[0]; + size_t class_dim = inference.dims()[1]; + *accuracy_data = 0.0f; + + if (num_samples == 0) { + return; + } + + int num_correct = 0; + // assume inference is already the topk of the output + for (size_t i = 0; i < num_samples; ++i) { + PADDLE_ENFORCE_GE( + label_data[i], + 0, + phi::errors::InvalidArgument( + "label of AccuracyOp must >= 0, But received label[%d] is %d", + i, + label_data[i])); + for (size_t j = 0; j < class_dim; ++j) { + if (indices_data[i * class_dim + j] == label_data[i]) { + ++num_correct; + break; + } + } + } + + *correct_data = num_correct; + *total_data = num_samples; + *accuracy_data = + static_cast(num_correct) / static_cast(num_samples); +} +} // namespace phi + +// TODO(add supported dtype.) +PD_REGISTER_KERNEL( + accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/accuracy_kernel.cu b/paddle/phi/kernels/gpu/accuracy_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..f08fb74e54d8c86f7b54d21c762e30cebedfe967 --- /dev/null +++ b/paddle/phi/kernels/gpu/accuracy_kernel.cu @@ -0,0 +1,117 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/accuracy_kernel.h" + +#include +#include +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void AccuracyCudaKernel(const int N, + const int D, + const int64_t* Xdata, + const int64_t* labeldata, + int* correct_data, + float* accuracy, + int* total_data) { + int count = 0; + __shared__ int total[BlockSize]; + + // support only 1 block + for (int i = threadIdx.x; i < (N); i += BlockSize) { + for (int j = 0; j < D; ++j) { + if (Xdata[i * D + j] == labeldata[i]) { + ++count; + break; + } + } + } + total[threadIdx.x] = count; + __syncthreads(); + +// reduce the count with init value 0, and output accuracy. +#ifdef PADDLE_WITH_CUDA + int result = thrust::reduce(thrust::device, total, total + BlockSize, 0); +#else + // HIP thrust::reduce not support __device__ + for (int s = BlockSize / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + total[threadIdx.x] += total[threadIdx.x + s]; + } + __syncthreads(); + } + int result = total[0]; +#endif + if (threadIdx.x == 0) { + *correct_data = result; + *accuracy = static_cast(result) / static_cast(N); + *total_data = N; + } +} + +template +void AccuracyRawKernel(const Context& dev_ctx, + const DenseTensor& inference, + const DenseTensor& indices, + const DenseTensor& label, + DenseTensor* accuracy, + DenseTensor* correct, + DenseTensor* total) { + // FIXME(typhoonzero): only support indices currently + // if add support for output values, how to detect the data type? + const int64_t* indices_data = indices.data(); + const int64_t* label_data = label.data(); + + int* correct_data = dev_ctx.template Alloc(correct); + int* total_data = dev_ctx.template Alloc(total); + float* accuracy_data = dev_ctx.template Alloc(accuracy); + + int num_samples = static_cast(inference.dims()[0]); + size_t infer_width = inference.dims()[1]; + auto stream = dev_ctx.stream(); + phi::backends::gpu::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream); + + if (num_samples == 0) { + return; + } + + AccuracyCudaKernel< + PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + num_samples, + infer_width, + indices_data, + label_data, + correct_data, + accuracy_data, + total_data); +} +} // namespace phi + +// FIXME(typhoonzero): types of T is for inference data. +// label data is always int64 +PD_REGISTER_KERNEL(accuracy, + GPU, + ALL_LAYOUT, + phi::AccuracyRawKernel, + phi::dtype::float16, + float, + double) {}