mean_iou_op.cu 6.2 KB
Newer Older
W
whs 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mean_iou_op.h"
16
#include "paddle/fluid/memory/malloc.h"
17
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
18
#include "paddle/phi/backends/gpu/gpu_primitives.h"
19
#include "paddle/phi/kernels/funcs/math_function.h"
W
whs 已提交
20 21 22 23

namespace paddle {
namespace operators {

24
using phi::PADDLE_CUDA_NUM_THREADS;
// Accumulates per-class confusion counts for one batch of predictions.
//
// For every sample i in [0, count):
//   * predictions[i] == labels[i] -> correct[pred] += 1
//   * otherwise                   -> wrong[pred] += 1 and wrong[label] += 1
// Samples whose prediction or label lies outside [0, num_classes) are
// ignored: using them as offsets would write past (or into the wrong half
// of) the shared-memory cache.
//
// Launch: 1-D grid of 1-D blocks; samples are visited with CUDA_KERNEL_LOOP,
// so the grid size only affects parallelism.
// Shared memory: needs 2 * num_classes * sizeof(int) bytes of dynamic shared
// memory (first half: wrong counts, second half: correct counts).
// `wrong` and `correct` are updated with atomicAdd, so the caller must
// initialise them before launch.
template <typename T>
__global__ void CountCUDAKernel(const int num_classes,
                                const int count,
                                const T* predictions,
                                const T* labels,
                                int* wrong,
                                int* correct) {
  extern __shared__ int block_cache[];
  int* wrong_c = block_cache;
  int* correct_c = block_cache + num_classes;

  // Zero the block-local histograms before any thread updates them.
  for (int i = threadIdx.x; i < num_classes * 2; i += blockDim.x) {
    block_cache[i] = 0;
  }
  __syncthreads();

  CUDA_KERNEL_LOOP(i, count) {
    const T pred = predictions[i];
    const T label = labels[i];
    const bool in_range = pred >= 0 && pred < num_classes && label >= 0 &&
                          label < num_classes;
    if (in_range) {
      if (pred == label) {
        atomicAdd(correct_c + pred, 1);
      } else {
        atomicAdd(wrong_c + pred, 1);
        atomicAdd(wrong_c + label, 1);
      }
    }
  }

  // Every block-local update must be complete before the flush below.
  __syncthreads();

  // Merge this block's partial histograms into the global outputs.
  for (int i = threadIdx.x; i < num_classes; i += blockDim.x) {
    atomicAdd(wrong + i, wrong_c[i]);
    atomicAdd(correct + i, correct_c[i]);
  }
}

// Computes per-class IoU and adds their mean over "valid" classes to iou[0].
//
// For each class i, denominator = wrong[i] + correct[i]:
//   * denominator > 0 -> ious[i] = correct[i] / denominator (class is valid)
//   * otherwise       -> ious[i] = 0
// Thread 0 then adds sum(ious) / (number of valid classes) to iou[0]. When no
// class is valid, the mean is undefined (0 / 0), so iou[0] is left unchanged
// instead of being turned into NaN.
//
// Launch: must use exactly ONE block. valid_count_c lives in per-block shared
// memory, and thread 0 reads ious[] entries written by other threads. Only
// the block-level __syncthreads() orders those reads.
// iou[0] is accumulated (+=), not overwritten.
__global__ void ComputeIoUCUDAKernel(
    const int num_classes, int* wrong, int* correct, float* ious, float* iou) {
  __shared__ int valid_count_c;
  if (threadIdx.x == 0) {
    valid_count_c = 0;
  }
  __syncthreads();

  CUDA_KERNEL_LOOP(i, num_classes) {
    const int wrong_n = wrong[i];
    const int correct_n = correct[i];
    const int denominator = wrong_n + correct_n;
    if (denominator > 0) {
      atomicAdd(&valid_count_c, 1);
      ious[i] = static_cast<float>(correct_n) / denominator;
    } else {
      ious[i] = 0.0f;
    }
  }
  // Make every ious[i] and the final valid_count_c visible to thread 0.
  __syncthreads();

  if (threadIdx.x == 0 && valid_count_c > 0) {
    float iou_sum = 0.0f;
    for (int i = 0; i < num_classes; ++i) {
      iou_sum += ious[i];
    }
    iou[0] += iou_sum / valid_count_c;
  }
}

template <typename T>
class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
 public:
  // Computes the running mean IoU of `Predictions` against `Labels` on GPU.
  //
  // Outputs:
  //   OutWrong / OutCorrect: sum of all InWrongs / InCorrects tensors plus
  //     this batch's per-class wrong / correct counts.
  //   OutMeanIou: sum of all InMeanIou tensors plus the mean IoU computed
  //     from the combined counts (classes with no samples are excluded).
  // Both kernels are launched on dev_ctx's stream.
  // NOTE(review): this assumes the Eigen device `place` is bound to that same
  // stream, as the original code also assumed.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
    auto& place = *dev_ctx.eigen_device();
    auto stream = dev_ctx.stream();

    // get input and output tensor
    auto* predictions = ctx.Input<phi::DenseTensor>("Predictions");
    auto* labels = ctx.Input<phi::DenseTensor>("Labels");
    auto* out_mean_iou = ctx.Output<phi::DenseTensor>("OutMeanIou");
    auto* out_wrong = ctx.Output<phi::DenseTensor>("OutWrong");
    auto* out_correct = ctx.Output<phi::DenseTensor>("OutCorrect");
    int num_classes = static_cast<int>(ctx.Attr<int>("num_classes"));

    // Get output data ptr
    int* out_wrong_data = out_wrong->mutable_data<int>(ctx.GetPlace());
    int* out_correct_data = out_correct->mutable_data<int>(ctx.GetPlace());
    float* out_mean_iou_data =
        out_mean_iou->mutable_data<float>(ctx.GetPlace());

    // Get Eigen tensor
    auto out_mean_iou_t = EigenTensor<float, 1>::From(*out_mean_iou);
    auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong);
    auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);

    // Temporary per-class IoU buffer, associated with `stream`.
    auto tmp_ious_data = memory::Alloc(
        dev_ctx.GetPlace(),
        num_classes * sizeof(float),
        phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
    float* ious_data = static_cast<float*>(tmp_ious_data->ptr());

    // Init out_wrong, out_correct and out_mean_iou
    out_wrong_t.device(place) = out_wrong_t.constant(0);
    out_correct_t.device(place) = out_correct_t.constant(0);
    out_mean_iou_t.device(place) = out_mean_iou_t.constant(0.0f);

    // collect previous wrong, correct and mean_iou
    auto in_mean_ious = ctx.MultiInput<phi::DenseTensor>("InMeanIou");
    for (const auto* in_mean_iou : in_mean_ious) {
      out_mean_iou_t.device(place) +=
          EigenTensor<float, 1>::From(*in_mean_iou);
    }
    auto in_wrongs = ctx.MultiInput<phi::DenseTensor>("InWrongs");
    for (const auto* in_wrong : in_wrongs) {
      out_wrong_t.device(place) += EigenTensor<int, 1>::From(*in_wrong);
    }
    auto in_corrects = ctx.MultiInput<phi::DenseTensor>("InCorrects");
    for (const auto* in_correct : in_corrects) {
      out_correct_t.device(place) += EigenTensor<int, 1>::From(*in_correct);
    }

    // compute
    const int64_t numel = predictions->numel();
    const int block = PADDLE_CUDA_NUM_THREADS;
    if (numel > 0) {
      // An empty batch has nothing to count. A grid of 0 blocks would also be
      // an invalid launch configuration. So the input buffers are touched and
      // the counting kernel is launched only when there are samples.
      const T* predictions_data = predictions->data<T>();
      const T* labels_data = labels->data<T>();
      const int grid = static_cast<int>((numel + block - 1) / block);
      // One wrong and one correct int histogram per block, matching the
      // 2 * num_classes entries CountCUDAKernel indexes.
      // NOTE(review): this must fit in the per-block shared memory limit;
      // confirm num_classes is bounded elsewhere.
      const size_t cache_size =
          2 * static_cast<size_t>(num_classes) * sizeof(int);
      // NOTE(review): the kernel takes an int sample count, so numel is
      // assumed to fit in int.
      CountCUDAKernel<T>
          <<<grid, block, cache_size, stream>>>(num_classes,
                                                static_cast<int>(numel),
                                                predictions_data,
                                                labels_data,
                                                out_wrong_data,
                                                out_correct_data);
    }

    // ComputeIoUCUDAKernel reduces inside one block, so it gets one block.
    ComputeIoUCUDAKernel<<<1, block, 0, stream>>>(num_classes,
                                                  out_wrong_data,
                                                  out_correct_data,
                                                  ious_data,
                                                  out_mean_iou_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// `int32_t` is a typedef of `int` on every platform CUDA supports, so a
// separate MeanIoUCUDAOpKernel<int32_t> entry would only register the int
// kernel a second time.
REGISTER_OP_CUDA_KERNEL(mean_iou,
                        ops::MeanIoUCUDAOpKernel<int>,
                        ops::MeanIoUCUDAOpKernel<int64_t>);