accuracy_op.cu 4.0 KB
Newer Older
武毅 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
武毅 已提交
17
#include "paddle/operators/accuracy_op.h"
18
#include "paddle/platform/cuda_helper.h"
D
dzhwinter 已提交
19
#include "paddle/platform/gpu_info.h"
武毅 已提交
20 21 22

namespace paddle {
namespace operators {
23
using platform::PADDLE_CUDA_NUM_THREADS;
武毅 已提交
24

武毅 已提交
25 26 27
// Counts the samples whose label occurs among their D candidate indices and
// writes both the hit count and the resulting accuracy.
//
// Launch contract: exactly one block, with blockDim.x == BlockSize. The kernel
// uses BlockSize * sizeof(int) bytes of static shared memory.
//
// Template parameter:
//   BlockSize    - number of threads in the single block (any positive value;
//                  it does not have to be a power of two).
// Parameters:
//   N            - number of samples; the accuracy is computed as hits / N, so
//                  the caller is expected to launch only when N > 0.
//   D            - number of candidate indices stored per sample.
//   Xdata        - candidate indices, read as Xdata[i * D + j].
//   labeldata    - one label per sample, read as labeldata[i].
//   correct_data - out: number of samples whose label matched a candidate.
//   accuracy     - out: hits / N as float.
template <int BlockSize>
__global__ void AccuracyCudaKernel(const int N, const int D,
                                   const int64_t* Xdata,
                                   const int64_t* labeldata, int* correct_data,
                                   float* accuracy) {
  __shared__ int total[BlockSize];
  const int tid = static_cast<int>(threadIdx.x);

  // Each thread walks the samples tid, tid + BlockSize, ... (single block).
  int count = 0;
  for (int i = tid; i < N; i += BlockSize) {
    const int64_t label = labeldata[i];
    // 64-bit row offset: i * D can exceed the range of int for large inputs.
    const int64_t* row = Xdata + static_cast<int64_t>(i) * D;
    for (int j = 0; j < D; ++j) {
      if (row[j] == label) {
        ++count;
        break;  // a sample is counted at most once
      }
    }
  }
  total[tid] = count;
  __syncthreads();

  // Cooperative in-block tree reduction. `active` is the number of live
  // partial sums; the upper part [half, active) is folded onto the lower part.
  // Writes touch indices < active / 2 and reads touch indices >= half, so a
  // round never reads a slot that is written in the same round. The trip count
  // depends only on BlockSize, so every thread reaches every barrier.
  for (int active = BlockSize; active > 1;) {
    const int half = (active + 1) / 2;
    if (tid < active / 2) {
      total[tid] += total[tid + half];
    }
    __syncthreads();
    active = half;
  }

  if (tid == 0) {
    const int result = total[0];
    *correct_data = result;
    *accuracy = static_cast<float>(result) / static_cast<float>(N);
  }
}

// GPU implementation of the accuracy operator.
//
// Inputs read from the execution context:
//   "Out"     - inference tensor; only its first two dimensions are used
//               (dims()[0] = number of samples, dims()[1] = candidates per
//               sample).
//   "Indices" - candidate indices, accessed as int64_t.
//   "Label"   - labels, accessed as int64_t.
// Outputs written on the context's CUDA stream:
//   "Accuracy" - one float, correct / total (0 when there are no samples).
//   "Correct"  - one int, number of samples whose label matched.
//   "Total"    - one int, number of samples.
//
// All work is enqueued on the device context's stream; the results stay on
// the device and nothing is read back to the host here.
template <typename T>
class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use GPUPlace.");
    auto* inference = ctx.Input<Tensor>("Out");
    auto* indices = ctx.Input<Tensor>("Indices");
    auto* label = ctx.Input<Tensor>("Label");

    auto* accuracy = ctx.Output<Tensor>("Accuracy");
    auto* correct = ctx.Output<Tensor>("Correct");
    auto* total = ctx.Output<Tensor>("Total");
    // FIXME(typhoonzero): only the indices are supported currently.
    // If support for the output values is added, how should their data type
    // be detected?
    const int64_t* indices_data = indices->data<int64_t>();
    const int64_t* label_data = label->data<int64_t>();

    int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
    int* total_data = total->mutable_data<int>(ctx.GetPlace());
    float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());

    const int num_samples = static_cast<int>(inference->dims()[0]);
    // The kernel takes the per-sample width as int; make the narrowing
    // explicit instead of relying on an implicit size_t -> int conversion.
    const int infer_width = static_cast<int>(inference->dims()[1]);
    auto stream = ctx.cuda_device_context().stream();
    platform::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream);

    if (num_samples == 0) {
      // The kernel is not launched for an empty batch (it would divide by
      // zero), so give the remaining outputs a defined value as well.
      platform::GpuMemsetAsync(correct_data, 0, sizeof(int), stream);
      platform::GpuMemsetAsync(total_data, 0, sizeof(int), stream);
      return;
    }

    platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int),
                             cudaMemcpyHostToDevice, stream);

    // The kernel supports a single block only; its threads stride over the
    // samples and reduce their partial counts in shared memory.
    AccuracyCudaKernel<
        PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        num_samples, infer_width, indices_data, label_data, correct_data,
        accuracy_data);
  }
};

}  // namespace operators
}  // namespace paddle

D
Dong Zhihong 已提交
105 106
// FIXME(typhoonzero): T is the type of the inference data;
// the label data is always int64.
武毅 已提交
107 108
// Passes the float and double instantiations of AccuracyOpCUDAKernel to
// REGISTER_OP_GPU_KERNEL under the operator name `accuracy`.
REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
                       paddle::operators::AccuracyOpCUDAKernel<double>);