/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

// Short aliases for the framework tensor types.
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

// Integer division rounded up: m / n, plus one when the remainder is positive.
// For non-negative m and positive n this is ceil(m / n). Both arguments are
// expanded twice, so pass side-effect-free expressions only.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

// Number of bits in a uint64_t. NMSKernel uses it as the tile width: one
// uint64_t mask word holds one suppression bit per box of a column tile.
int const kThreadsPerBlock = sizeof(uint64_t) * 8;

// Upper bound applied to the width/height deltas before they are passed to
// exp() in BoxDecodeAndClipFunctor.
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);

// Writes an arithmetic sequence into an int buffer: element i receives
// start_ + i * delta_. Kept as a plain aggregate so that it can be
// brace-initialized as RangeInitFunctor{start, delta, out}.
struct RangeInitFunctor {
  int start_;  // value stored at out_[0]
  int delta_;  // step between consecutive elements
  int *out_;   // destination buffer, indexed by the call argument

  __device__ void operator()(size_t i) {
    // The sum is evaluated in size_t (because i is size_t) and then narrowed.
    out_[i] = static_cast<int>(start_ + i * delta_);
  }
};

// Sorts the elements of `value` in descending order on the GPU.
//
//   ctx       - GPU context; supplies the place for allocations and the stream
//               on which all work is enqueued.
//   value     - keys to sort; all value.numel() elements are treated as one
//               flat array.
//   value_out - resized to {num}; receives the keys in descending order.
//   index_out - resized to {num} (int); entry j receives the position in
//               `value` of the key stored at value_out[j].
//
// The status codes returned by the CUB calls are not inspected.
template <typename T>
static void SortDescending(const phi::GPUContext &ctx,
                           const Tensor &value,
                           Tensor *value_out,
                           Tensor *index_out) {
  const int num = static_cast<int>(value.numel());
  auto place = ctx.GetPlace();

  // Payload for the sort: the identity permutation 0, 1, ..., num - 1.
  Tensor iota_t;
  int *iota = iota_t.mutable_data<int>({num}, place);
  platform::ForRange<phi::GPUContext> fill_iota(ctx, num);
  fill_iota(RangeInitFunctor{0, 1, iota});

  int *sorted_idx = index_out->mutable_data<int>({num}, place);
  const T *src_keys = value.data<T>();
  T *sorted_keys = value_out->mutable_data<T>({num}, place);

  // CUB is invoked twice with identical arguments: once with a null workspace
  // to obtain the required size, then with the real workspace to do the sort.
  size_t workspace_bytes = 0;
  auto radix_sort = [&](void *workspace) {
    cub::DeviceRadixSort::SortPairsDescending<T, int>(workspace,
                                                      workspace_bytes,
                                                      src_keys,
                                                      sorted_keys,
                                                      iota,
                                                      sorted_idx,
                                                      num,
                                                      0,
                                                      sizeof(T) * 8,
                                                      ctx.stream());
  };

  radix_sort(nullptr);
  auto workspace = memory::Alloc(place, workspace_bytes);
  radix_sort(workspace->ptr());
}

// Per-element functor that turns (anchor, delta) pairs into clipped boxes.
//
// For output slot i the functor reads row index[i] of `anchor` and `deltas`
// (4 values per row), shifts the anchor centre by the first two deltas, scales
// the anchor size by exp() of the last two deltas (each capped at
// bbox_clip_default), and writes the resulting corners, clamped to
// [0, im_info[1] - offset] in x and [0, im_info[0] - offset] in y, to
// proposals[4 * i .. 4 * i + 3]. `offset` is 1 when pixel_offset is set and 0
// otherwise. When `var` is non-null each delta is additionally multiplied by
// the matching entry of `var`.
template <typename T>
struct BoxDecodeAndClipFunctor {
  const T *anchor;   // anchors as xmin, ymin, xmax, ymax
  const T *deltas;   // regression deltas, same row layout as anchor
  const T *var;      // optional per-value scale factors; may be null
  const int *index;  // index[i] = source row for output i
  const T *im_info;  // im_info[0] limits y, im_info[1] limits x
  const bool pixel_offset;

  T *proposals;  // output, 4 values per box

  BoxDecodeAndClipFunctor(const T *anchor,
                          const T *deltas,
                          const T *var,
                          const int *index,
                          const T *im_info,
                          T *proposals,
                          bool pixel_offset = true)
      : anchor(anchor),
        deltas(deltas),
        var(var),
        index(index),
        im_info(im_info),
        pixel_offset(pixel_offset),
        proposals(proposals) {}

  // Cap for the size deltas before exponentiation.
  T bbox_clip_default{static_cast<T>(kBBoxClipDefault)};

  __device__ void operator()(size_t i) {
    const int base = index[i] * 4;
    const T offset = pixel_offset ? static_cast<T>(1.0) : 0;

    // Anchor expressed as size and centre.
    const T anchor_x0 = anchor[base];
    const T anchor_y0 = anchor[base + 1];
    const T anchor_x1 = anchor[base + 2];
    const T anchor_y1 = anchor[base + 3];
    const T anchor_w = anchor_x1 - anchor_x0 + offset;
    const T anchor_h = anchor_y1 - anchor_y0 + offset;
    const T anchor_cx = anchor_x0 + 0.5 * anchor_w;
    const T anchor_cy = anchor_y0 + 0.5 * anchor_h;

    const T shift_x = deltas[base];
    const T shift_y = deltas[base + 1];
    const T log_sx = deltas[base + 2];
    const T log_sy = deltas[base + 3];

    // Decoded centre and size.
    T ctr_x, ctr_y, box_w, box_h;
    if (!var) {
      ctr_x = anchor_cx + shift_x * anchor_w;
      ctr_y = anchor_cy + shift_y * anchor_h;
      box_w = exp(Min(log_sx, bbox_clip_default)) * anchor_w;
      box_h = exp(Min(log_sy, bbox_clip_default)) * anchor_h;
    } else {
      ctr_x = anchor_cx + shift_x * anchor_w * var[base];
      ctr_y = anchor_cy + shift_y * anchor_h * var[base + 1];
      box_w = exp(Min(log_sx * var[base + 2], bbox_clip_default)) * anchor_w;
      box_h = exp(Min(log_sy * var[base + 3], bbox_clip_default)) * anchor_h;
    }

    // Back to corner form.
    const T x0 = ctr_x - box_w * 0.5;
    const T y0 = ctr_y - box_h * 0.5;
    const T x1 = ctr_x + box_w * 0.5 - offset;
    const T y1 = ctr_y + box_h * 0.5 - offset;

    // Clamp every corner into the image.
    T *out = proposals + i * 4;
    out[0] = Max(Min(x0, im_info[1] - offset), 0.);
    out[1] = Max(Min(y0, im_info[0] - offset), 0.);
    out[2] = Max(Min(x1, im_info[1] - offset), 0.);
    out[3] = Max(Min(y1, im_info[0] - offset), 0.);
  }

  // Returns b when a > b, otherwise a.
  __device__ __forceinline__ T Min(T a, T b) const { return a > b ? b : a; }

  // Returns a when a > b, otherwise b.
  __device__ __forceinline__ T Max(T a, T b) const { return a > b ? a : b; }
};

// Collects the indices of boxes that pass the minimum-size test.
//
// Launch layout: written for a single block of BlockSize threads. Thread 0 of
// a block appends to keep[] starting at index 0 and stores the final count in
// keep_num[0], so several blocks would overwrite each other's output.
// Shared memory: BlockSize ints (static).
//
//   bboxes       - boxes as xmin, ymin, xmax, ymax (4 values per box).
//   im_info      - im_info[0] / im_info[1] are used as image height / width;
//                  im_info[2] divides the box extent when is_scale is set.
//   min_size     - minimum accepted width and height.
//   num          - number of boxes.
//   keep_num     - out: number of indices written to keep[].
//   keep         - out: indices of the accepted boxes, in ascending order.
//   is_scale     - only consulted when pixel_offset is set; see im_info.
//   pixel_offset - adds 1 to the box extents and additionally requires the box
//                  centre to lie within (im_w, im_h).
//
// The boxes are processed tile by tile. The loop condition depends only on
// block-uniform values, so every thread of the block runs the same number of
// iterations and reaches every __syncthreads(); threads whose box index falls
// past `num` in the last tile simply skip the per-box work.
template <typename T, int BlockSize>
static __global__ void FilterBBoxes(const T *bboxes,
                                    const T *im_info,
                                    const T min_size,
                                    const int num,
                                    int *keep_num,
                                    int *keep,
                                    bool is_scale = true,
                                    bool pixel_offset = true) {
  T im_h = im_info[0];
  T im_w = im_info[1];

  int cnt = 0;
  __shared__ int keep_index[BlockSize];

  const int tile_stride = blockDim.x * gridDim.x;
  for (int tile_begin = blockIdx.x * blockDim.x; tile_begin < num;
       tile_begin += tile_stride) {
    const int i = tile_begin + threadIdx.x;

    keep_index[threadIdx.x] = -1;
    __syncthreads();

    if (i < num) {
      int k = i * 4;
      T xmin = bboxes[k];
      T ymin = bboxes[k + 1];
      T xmax = bboxes[k + 2];
      T ymax = bboxes[k + 3];
      T offset = pixel_offset ? static_cast<T>(1.0) : 0;
      T w = xmax - xmin + offset;
      T h = ymax - ymin + offset;
      if (pixel_offset) {
        // The centre is taken from the unscaled extent.
        T cx = xmin + w / 2.;
        T cy = ymin + h / 2.;

        if (is_scale) {
          w = (xmax - xmin) / im_info[2] + 1.;
          h = (ymax - ymin) / im_info[2] + 1.;
        }

        if (w >= min_size && h >= min_size && cx <= im_w && cy <= im_h) {
          keep_index[threadIdx.x] = i;
        }
      } else {
        if (w >= min_size && h >= min_size) {
          keep_index[threadIdx.x] = i;
        }
      }
    }
    __syncthreads();

    // Thread 0 compacts the tile serially, which keeps the indices ordered.
    if (threadIdx.x == 0) {
      // For thread 0, i == tile_begin, so this is the number of valid slots.
      int size = (num - i) < BlockSize ? num - i : BlockSize;
      for (int j = 0; j < size; ++j) {
        if (keep_index[j] > -1) {
          keep[cnt++] = keep_index[j];
        }
      }
    }
    __syncthreads();
  }

  if (threadIdx.x == 0) {
    keep_num[0] = cnt;
  }
}

// Intersection-over-union of two boxes given as xmin, ymin, xmax, ymax.
// When pixel_offset is set, 1 is added to every width and height before the
// areas are formed. The intersection extents are floored at 0; the union is
// not checked for being zero.
static __device__ float IoU(const float *a,
                            const float *b,
                            const bool pixel_offset = true) {
  const float offset = pixel_offset ? static_cast<float>(1.0) : 0;

  // Corners of the overlap rectangle.
  const float inter_x0 = max(a[0], b[0]);
  const float inter_x1 = min(a[2], b[2]);
  const float inter_y0 = max(a[1], b[1]);
  const float inter_y1 = min(a[3], b[3]);

  const float inter_w = max(inter_x1 - inter_x0 + offset, 0.f);
  const float inter_h = max(inter_y1 - inter_y0 + offset, 0.f);
  const float inter_area = inter_w * inter_h;

  const float area_a = (a[2] - a[0] + offset) * (a[3] - a[1] + offset);
  const float area_b = (b[2] - b[0] + offset) * (b[3] - b[1] + offset);

  return inter_area / (area_a + area_b - inter_area);
}

// Builds the pairwise suppression bit mask used by NMS.
//
// Launch layout: 2-D grid in which blockIdx.y selects a tile of "row" boxes
// and blockIdx.x a tile of "column" boxes; each tile holds up to
// kThreadsPerBlock boxes and the block is indexed by threadIdx.x only.
// Shared memory: kThreadsPerBlock * 4 floats (static).
//
//   n_boxes            - number of boxes in dev_boxes.
//   nms_overlap_thresh - a pair is flagged when its IoU exceeds this value.
//   dev_boxes          - boxes as xmin, ymin, xmax, ymax (4 floats per box).
//   dev_mask           - out: n_boxes x DIVUP(n_boxes, kThreadsPerBlock)
//                        words; bit i of word (r, c) is set when box r
//                        overlaps box c * kThreadsPerBlock + i.
//   pixel_offset       - forwarded to IoU().
//
// On diagonal tiles a row box is only compared with the column boxes that
// follow it inside the tile.
static __global__ void NMSKernel(const int n_boxes,
                                 const float nms_overlap_thresh,
                                 const float *dev_boxes,
                                 uint64_t *dev_mask,
                                 bool pixel_offset = true) {
  const int row_tile = blockIdx.y;
  const int col_tile = blockIdx.x;

  // The last tile in either direction may be only partially filled.
  const int rows_here =
      min(n_boxes - row_tile * kThreadsPerBlock, kThreadsPerBlock);
  const int cols_here =
      min(n_boxes - col_tile * kThreadsPerBlock, kThreadsPerBlock);

  // Stage the column tile in shared memory, one box per thread.
  __shared__ float tile_boxes[kThreadsPerBlock * 4];
  if (threadIdx.x < cols_here) {
    const float *src =
        dev_boxes + (kThreadsPerBlock * col_tile + threadIdx.x) * 4;
    float *dst = tile_boxes + threadIdx.x * 4;
    for (int c = 0; c < 4; ++c) {
      dst[c] = src[c];
    }
  }
  __syncthreads();

  // No barrier follows, so threads without a row box can leave here.
  if (!(threadIdx.x < rows_here)) {
    return;
  }

  const int row_box = kThreadsPerBlock * row_tile + threadIdx.x;
  const float *row_ptr = dev_boxes + row_box * 4;

  const int first_col =
      (row_tile == col_tile) ? static_cast<int>(threadIdx.x + 1) : 0;

  uint64_t bits = 0;
  for (int c = first_col; c < cols_here; c++) {
    const float overlap = IoU(row_ptr, tile_boxes + c * 4, pixel_offset);
    if (overlap > nms_overlap_thresh) {
      bits |= 1ULL << c;
    }
  }

  const int words_per_row = DIVUP(n_boxes, kThreadsPerBlock);
  dev_mask[row_box * words_per_row + col_tile] = bits;
}

// Greedy non-maximum suppression over `proposals`.
//
//   ctx            - GPU context; supplies the place, the stream and Wait().
//   proposals      - boxes, one per row (proposals.dims()[0] rows); the data
//                    pointer is handed to NMSKernel, which reads 4 floats per
//                    box.
//   sorted_indices - not used by this function.
//   nms_threshold  - a box is suppressed when its IoU with an earlier kept box
//                    exceeds this value.
//   keep_out       - resized to {num_to_keep} (int); receives the row indices
//                    of the kept boxes in ascending order.
//   pixel_offset   - forwarded to NMSKernel.
//
// Boxes are visited in row order, so earlier rows take precedence. The
// function returns only after the stream has been drained (ctx.Wait()).
// An empty `proposals` yields an empty keep_out without launching the kernel.
template <typename T>
static void NMS(const phi::GPUContext &ctx,
                const Tensor &proposals,
                const Tensor &sorted_indices,
                const T nms_threshold,
                Tensor *keep_out,
                bool pixel_offset = true) {
  const int boxes_num = static_cast<int>(proposals.dims()[0]);
  const int col_blocks = DIVUP(boxes_num, kThreadsPerBlock);
  auto place = ctx.GetPlace();

  std::vector<int> keep_vec;

  // A zero-sized grid is an invalid launch configuration, so the whole GPU
  // pass is skipped when there is nothing to process.
  if (boxes_num > 0) {
    dim3 blocks(col_blocks, col_blocks);
    dim3 threads(kThreadsPerBlock);

    const T *boxes = proposals.data<T>();

    // Sizes are formed in size_t so that boxes_num * col_blocks cannot
    // overflow int.
    const size_t mask_words = static_cast<size_t>(boxes_num) * col_blocks;
    const size_t mask_bytes = mask_words * sizeof(uint64_t);

    auto mask_ptr = memory::Alloc(ctx, mask_bytes);
    uint64_t *mask_dev = reinterpret_cast<uint64_t *>(mask_ptr->ptr());

    NMSKernel<<<blocks, threads, 0, ctx.stream()>>>(
        boxes_num, nms_threshold, boxes, mask_dev, pixel_offset);

    std::vector<uint64_t> mask_host(mask_words);
    memory::Copy(platform::CPUPlace(),
                 mask_host.data(),
                 place,
                 mask_dev,
                 mask_bytes,
                 ctx.stream());
    // The copy is ordered on ctx.stream(); make sure it has finished before
    // the host starts reading mask_host.
    ctx.Wait();

    // remv holds one bit per box: set once the box has been suppressed.
    // std::vector value-initializes its elements, so it starts all-zero.
    std::vector<uint64_t> remv(col_blocks);

    for (int i = 0; i < boxes_num; i++) {
      const int nblock = i / kThreadsPerBlock;
      const int inblock = i % kThreadsPerBlock;

      if (!(remv[nblock] & (1ULL << inblock))) {
        keep_vec.push_back(i);
        // Fold in everything box i suppresses. Words before nblock only
        // describe boxes that have already been visited.
        const uint64_t *row =
            mask_host.data() + static_cast<size_t>(i) * col_blocks;
        for (int j = nblock; j < col_blocks; j++) {
          remv[j] |= row[j];
        }
      }
    }
  }

  const int num_to_keep = static_cast<int>(keep_vec.size());
  int *keep = keep_out->mutable_data<int>({num_to_keep}, place);
  if (num_to_keep > 0) {
    memory::Copy(place,
                 keep,
                 platform::CPUPlace(),
                 keep_vec.data(),
                 sizeof(int) * num_to_keep,
                 ctx.stream());
  }
  // keep_vec is a local; the upload must complete before it is destroyed.
  ctx.Wait();
}

}  // namespace operators
}  // namespace paddle