accuracy_op.cu 3.0 KB
Newer Older
武毅 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
武毅 已提交
17
#include "paddle/operators/accuracy_op.h"
18
#include "paddle/platform/cuda_helper.h"
武毅 已提交
19 20 21

namespace paddle {
namespace operators {
22
using platform::PADDLE_CUDA_NUM_THREADS;
武毅 已提交
23

武毅 已提交
24 25 26 27
template <int BlockSize>
// Computes the fraction of samples whose label appears among their predicted
// indices, and writes it to *accuracy.
//
// Xdata is read as an N x D row-major array: row i (D predicted indices for
// sample i) starts at Xdata[i * D]. labeldata holds one label per sample.
// Sample i is counted as correct when any of its D indices equals
// labeldata[i].
//
// Launch contract:
//   * Designed for a single block: the per-sample loop strides by BlockSize
//     only and ignores blockIdx, so extra blocks would just recompute and
//     rewrite the same value.
//   * blockDim.x must equal BlockSize: every slot of the shared `total`
//     array is written by the thread with that index before being summed.
//   * N must be > 0; N == 0 would produce 0.0f / 0.0f (NaN).
__global__ void AccuracyCudaKernel(const int N, const int D,
                                   const int64_t* Xdata,
                                   const int64_t* labeldata, float* accuracy) {
  __shared__ int total[BlockSize];
  const int tid = threadIdx.x;

  // Each thread counts the correct samples in its strided slice of rows.
  int count = 0;
  for (int i = tid; i < N; i += BlockSize) {
    // 64-bit row offset: i * D can exceed INT_MAX for large inputs.
    const int64_t* row = Xdata + static_cast<int64_t>(i) * D;
    const int64_t label = labeldata[i];
    for (int j = 0; j < D; ++j) {
      if (row[j] == label) {
        ++count;
        break;
      }
    }
  }
  total[tid] = count;
  __syncthreads();

  // Tree reduction in shared memory. It replaces a per-thread sequential
  // thrust::reduce, which made every thread sum all BlockSize slots.
  // `active` depends only on the compile-time BlockSize, so every thread runs
  // the same number of iterations and reaches each __syncthreads().
  // Rounding `half` up handles block sizes that are not powers of two.
  for (int active = BlockSize; active > 1;) {
    const int half = (active + 1) / 2;
    if (tid < active - half) {
      total[tid] += total[tid + half];
    }
    __syncthreads();
    active = half;
  }

  if (tid == 0) {
    *accuracy = static_cast<float>(total[0]) / static_cast<float>(N);
  }
}

// GPU kernel for the `accuracy` operator.
//
// Reads the int64 "Indices" input (predicted indices, with width taken from
// the "Out" input's dims()[1]) and the int64 "Label" input. Writes a single
// float to the "Accuracy" output: the fraction of the dims()[0] samples whose
// label appears among their predicted indices. When there are no samples,
// the output is 0.
template <typename T>
class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use GPUPlace.");
    auto* inference = ctx.Input<Tensor>("Out");
    auto* indices = ctx.Input<Tensor>("Indices");
    auto* label = ctx.Input<Tensor>("Label");
    auto* accuracy = ctx.Output<Tensor>("Accuracy");
    // FIXME(typhoonzero): only support indices currently
    // if add support for output values, how to detect the data type?
    const int64_t* indices_data = indices->data<int64_t>();
    const int64_t* label_data = label->data<int64_t>();
    float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());

    size_t num_samples = inference->dims()[0];
    size_t infer_width = inference->dims()[1];
    auto stream = ctx.cuda_device_context().stream();

    // Zero the device-side result so it is well defined even when there are
    // no samples. The pointer itself is passed (not its address). The memset
    // is issued on the op's stream, which orders it before the kernel below.
    PADDLE_ENFORCE(cudaMemsetAsync(accuracy_data, 0, sizeof(float), stream),
                   "cudaMemsetAsync failed in AccuracyOpCUDAKernel");

    if (num_samples == 0) {
      return;
    }

    // The kernel is written for exactly one block of PADDLE_CUDA_NUM_THREADS
    // threads.
    AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0,
                                                  stream>>>(
        num_samples, infer_width, indices_data, label_data, accuracy_data);
    // Kernel launches do not return errors; surface bad launch configs here.
    PADDLE_ENFORCE(cudaGetLastError(),
                   "AccuracyCudaKernel launch failed in AccuracyOpCUDAKernel");
  }
};

}  // namespace operators
}  // namespace paddle

武毅 已提交
83 84 85 86
// FIXME(typhoonzero): T only selects the registered kernel type for the
// inference data; the kernel body reads indices and labels as int64_t
// regardless of T.
REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
                       paddle::operators::AccuracyOpCUDAKernel<double>);