/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include "paddle/fluid/operators/metrics/accuracy_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;

// Counts how many of the N samples have their label among the D candidate
// indices of that sample. Writes that count to *correct_data, count / N to
// *accuracy and N to *total_data.
//
// Launch contract:
//   * exactly one block (gridDim.x == 1): results are not combined across
//     blocks;
//   * blockDim.x == BlockSize: every slot of the shared `total` array must be
//     written by the thread whose threadIdx.x matches it;
//   * N > 0 (the host wrapper handles N == 0 without launching).
// Xdata is indexed as a row-major [N, D] array and labeldata as N values, one
// per sample. Presumably these are the "Indices" and "Label" tensors; confirm
// against the op's InferShape.
template <int BlockSize>
__global__ void AccuracyCudaKernel(const int N, const int D,
                                   const int64_t* Xdata,
                                   const int64_t* labeldata, int* correct_data,
                                   float* accuracy, int* total_data) {
  static_assert(BlockSize > 0, "BlockSize must be positive");
  const int tid = static_cast<int>(threadIdx.x);
  int count = 0;
  __shared__ int total[BlockSize];

  // support only 1 block
  for (int i = tid; i < N; i += BlockSize) {
    // 64-bit row offset: i * D can exceed INT_MAX for large N * D.
    const int64_t row = static_cast<int64_t>(i) * D;
    const int64_t label = labeldata[i];
    for (int j = 0; j < D; ++j) {
      if (Xdata[row + j] == label) {
        ++count;
        break;
      }
    }
  }
  total[tid] = count;
  __syncthreads();

  // Tree-reduce the per-thread counts in shared memory. Each step folds the
  // upper part [half, active) onto [0, active - half). The read and write
  // ranges never overlap within a step. The loop bounds depend only on
  // BlockSize, so every thread reaches every __syncthreads(). This also works
  // when BlockSize is not a power of two.
  for (int active = BlockSize; active > 1;) {
    const int half = (active + 1) / 2;
    if (tid + half < active) {
      total[tid] += total[tid + half];
    }
    __syncthreads();
    active = half;
  }

  if (tid == 0) {
    const int result = total[0];
    *correct_data = result;
    *accuracy = static_cast<float>(result) / static_cast<float>(N);
    *total_data = N;
  }
}

// CUDA kernel for the accuracy op. Compares the "Indices" input against
// "Label" and writes the scalar outputs "Accuracy" (float), "Correct" (int)
// and "Total" (int). T is the registered element type of the "Out" input. Only
// its dims are read here, not its values.
template <typename T>
class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* inference = ctx.Input<framework::Tensor>("Out");
    auto* indices = ctx.Input<framework::Tensor>("Indices");
    auto* label = ctx.Input<framework::Tensor>("Label");

    auto* accuracy = ctx.Output<framework::Tensor>("Accuracy");
    auto* correct = ctx.Output<framework::Tensor>("Correct");
    auto* total = ctx.Output<framework::Tensor>("Total");
    // FIXME(typhoonzero): only support indices currently
    // if add support for output values, how to detect the data type?
    const int64_t* indices_data = indices->data<int64_t>();
    const int64_t* label_data = label->data<int64_t>();

    int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
    int* total_data = total->mutable_data<int>(ctx.GetPlace());
    float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());

    int num_samples = static_cast<int>(inference->dims()[0]);
    // The kernel takes the width as int; make the narrowing explicit.
    int infer_width = static_cast<int>(inference->dims()[1]);
    auto stream = ctx.cuda_device_context().stream();

    if (num_samples == 0) {
      // No kernel launch: zero every output so none is left uninitialized.
      platform::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream);
      platform::GpuMemsetAsync(correct_data, 0, sizeof(int), stream);
      platform::GpuMemsetAsync(total_data, 0, sizeof(int), stream);
      return;
    }

    // Single block whose size equals the kernel's BlockSize, as the kernel
    // requires.
    AccuracyCudaKernel<
        PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        num_samples, infer_width, indices_data, label_data, correct_data,
        accuracy_data, total_data);
  }
};

}  // namespace operators
}  // namespace paddle

// FIXME(typhoonzero): types of T is for inference data.
// label data is always int64
REGISTER_OP_CUDA_KERNEL(
    accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
    paddle::operators::AccuracyOpCUDAKernel<double>,
    paddle::operators::AccuracyOpCUDAKernel<paddle::platform::float16>);