// generate_proposals_kernel.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/generate_proposals_kernel.h"

#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h"

#include "paddle/fluid/memory/memcpy.h"

namespace phi {

// Ranks the elements of `value` by descending score and writes the resulting
// element indices to `index_out`.
//
// The scores are copied to host memory, ranked there with the standard
// library, and the selected indices are copied back to the place of `dev_ctx`.
//
//   dev_ctx        context used for the host and device allocations.
//   value          tensor whose elements of type T are the scores to rank.
//   index_out      output; resized to 1-D and filled with `int` indices.
//   pre_nms_top_n  number of best indices to keep. When it is <= 0 or not
//                  smaller than value.numel(), every index is kept and the
//                  whole range is sorted.
template <typename T>
static void SortDescending(const XPUContext& dev_ctx,
                           const DenseTensor& value,
                           DenseTensor* index_out,
                           int pre_nms_top_n) {
  const int64_t total = value.numel();
  const auto place = dev_ctx.GetPlace();
  const auto cpu_place = phi::CPUPlace();

  // Stage the scores in host memory so they can be ranked on the CPU.
  DenseTensor scores_slice_cpu;
  scores_slice_cpu.Resize({total});
  T* scores_host = dev_ctx.template HostAlloc<T>(&scores_slice_cpu);
  paddle::memory::Copy(
      cpu_place, scores_host, place, value.data<T>(), sizeof(T) * total);

  // Start from the identity permutation 0, 1, ..., total - 1.
  DenseTensor index_t;
  index_t.Resize({total});
  int* index = dev_ctx.template HostAlloc<int>(&index_t);
  std::iota(index, index + total, 0);

  // The comparator works on the same element type that is being sorted.
  const auto compare = [scores_host](int lhs, int rhs) {
    return scores_host[lhs] > scores_host[rhs];
  };

  int64_t kept = total;
  if (pre_nms_top_n <= 0 || pre_nms_top_n >= total) {
    std::sort(index, index + total, compare);
  } else {
    // Only the best pre_nms_top_n entries are needed: partition first, then
    // sort just that prefix.
    std::nth_element(index, index + pre_nms_top_n, index + total, compare);
    std::sort(index, index + pre_nms_top_n, compare);
    kept = pre_nms_top_n;
  }

  index_out->Resize({kept});
  int* idx_out = dev_ctx.template Alloc<int>(index_out);
  // The buffer holds `int` indices, so its byte size must not depend on T.
  paddle::memory::Copy(place, idx_out, cpu_place, index, sizeof(int) * kept);
}

// Produces the proposals and their scores for a single image.
//
// Steps, in order:
//   1. rank the scores (SortDescending) and gather the selected scores, box
//      deltas, anchors and variances;
//   2. decode the boxes with xpu::box_decoder;
//   3. drop small boxes with xpu::remove_small_boxes;
//   4. run xpu::sorted_nms and keep at most post_nms_top_n boxes.
//
//   dev_ctx            XPU context used for allocation and XDNN calls.
//   im_shape_slice     image shape data handed to the decoder and the filter.
//   anchors            anchors, gathered as rows of 4 values.
//   variances          variances, gathered as rows of 4 values.
//   bbox_deltas_slice  box deltas, [M, 4].
//   scores_slice       scores, [N, 1].
//   pre_nms_top_n      number of best-scoring entries kept before NMS; a
//                      value <= 0 or above N keeps all of them.
//   post_nms_top_n     upper bound on boxes kept after NMS when > 0.
//   nms_thresh         NMS threshold; when <= 0 the NMS step is skipped.
//   min_size           minimum box size, raised to at least 1.0.
//   eta                not used by this function.
//   pixel_offset       forwarded to the decode / filter / NMS primitives.
//
// Returns {proposals, scores}. When the size filter keeps no box, a single
// all-zero row is returned for each of them.
template <typename T>
std::pair<DenseTensor, DenseTensor> ProposalForOneImage(
    const phi::XPUContext& dev_ctx,
    const DenseTensor& im_shape_slice,
    const DenseTensor& anchors,
    const DenseTensor& variances,
    const DenseTensor& bbox_deltas_slice,  // [M, 4]
    const DenseTensor& scores_slice,       // [N, 1]
    int pre_nms_top_n,
    int post_nms_top_n,
    float nms_thresh,
    float min_size,
    float eta,
    bool pixel_offset = true) {
  auto* xctx = dev_ctx.x_context();

  // Resizes `t` to [rows, cols] and allocates it for elements of type T.
  const auto alloc_rows = [&dev_ctx](
                              DenseTensor* t, int64_t rows, int64_t cols) {
    t->Resize(phi::make_ddim({rows, cols}));
    dev_ctx.template Alloc<T>(t);
  };

  // Gathers `index_len` rows of the [rows, cols] tensor `src` into `dst`
  // along axis 0, using the `int` indices stored in `indices`.
  const auto gather_rows = [xctx](const DenseTensor& src,
                                  const DenseTensor& indices,
                                  int64_t index_len,
                                  int rows,
                                  int cols,
                                  DenseTensor* dst) {
    int ret = xpu::gather<T>(xctx,
                             src.data<T>(),
                             indices.data<int>(),
                             dst->data<T>(),
                             {rows, cols},
                             index_len,
                             0);
    PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather");
  };

  // 1. pre nms: rank the scores and pick the matching rows of every input.
  DenseTensor sorted_idx;
  SortDescending<T>(dev_ctx, scores_slice, &sorted_idx, pre_nms_top_n);

  DenseTensor top_scores, top_deltas, top_anchors, top_vars;
  alloc_rows(&top_scores, sorted_idx.numel(), 1);
  alloc_rows(&top_deltas, sorted_idx.numel(), 4);
  alloc_rows(&top_anchors, sorted_idx.numel(), 4);
  alloc_rows(&top_vars, sorted_idx.numel(), 4);

  gather_rows(scores_slice,
              sorted_idx,
              sorted_idx.numel(),
              static_cast<int>(scores_slice.numel()),
              1,
              &top_scores);
  gather_rows(bbox_deltas_slice,
              sorted_idx,
              sorted_idx.numel(),
              static_cast<int>(bbox_deltas_slice.numel()) / 4,
              4,
              &top_deltas);
  gather_rows(anchors,
              sorted_idx,
              sorted_idx.numel(),
              static_cast<int>(anchors.numel()) / 4,
              4,
              &top_anchors);
  gather_rows(variances,
              sorted_idx,
              sorted_idx.numel(),
              static_cast<int>(variances.numel()) / 4,
              4,
              &top_vars);

  const int total = scores_slice.numel();
  const bool take_all = pre_nms_top_n <= 0 || pre_nms_top_n > total;
  const int pre_nms_num = take_all ? total : pre_nms_top_n;
  top_scores.Resize({pre_nms_num, 1});
  sorted_idx.Resize({pre_nms_num, 1});

  // 2. box decode and clipping
  DenseTensor decoded;
  alloc_rows(&decoded, sorted_idx.numel(), 4);

  int r = xpu::box_decoder<T>(xctx,
                              top_anchors.data<T>(),
                              top_vars.data<T>(),
                              top_deltas.data<T>(),
                              decoded.data<T>(),
                              pre_nms_num,
                              !pixel_offset,
                              true,
                              im_shape_slice.data<T>());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "box_decoder");

  // 3. filter: drop boxes smaller than min_size.
  DenseTensor kept_idx, kept_count_dev;
  kept_idx.Resize(phi::make_ddim({pre_nms_num}));
  dev_ctx.template Alloc<int>(&kept_idx);

  kept_count_dev.Resize(phi::make_ddim({1}));
  dev_ctx.template Alloc<int>(&kept_count_dev);
  min_size = std::max(min_size, 1.0f);
  r = xpu::remove_small_boxes<T>(xctx,
                                 decoded.data<T>(),
                                 im_shape_slice.data<T>(),
                                 kept_idx.data<int>(),
                                 kept_count_dev.data<int>(),
                                 pre_nms_num,
                                 min_size,
                                 false,
                                 pixel_offset);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "remove_small_boxes");

  // The number of surviving boxes is produced on the device; bring it to the
  // host because it sizes the tensors below.
  int keep_num;
  const auto xpu_place = dev_ctx.GetPlace();
  paddle::memory::Copy(phi::CPUPlace(),
                       &keep_num,
                       xpu_place,
                       kept_count_dev.data<int>(),
                       sizeof(int));
  kept_idx.Resize({keep_num});

  DenseTensor scores_filter, proposals_filter;
  if (keep_num == 0) {
    // Nothing survived the filter: hand back one zero-filled row each.
    phi::funcs::SetConstant<phi::XPUContext, T> set_zero;
    alloc_rows(&proposals_filter, 1, 4);
    alloc_rows(&scores_filter, 1, 1);
    set_zero(dev_ctx, &proposals_filter, static_cast<T>(0));
    set_zero(dev_ctx, &scores_filter, static_cast<T>(0));
    return std::make_pair(proposals_filter, scores_filter);
  }

  alloc_rows(&proposals_filter, keep_num, 4);
  alloc_rows(&scores_filter, keep_num, 1);
  gather_rows(decoded, kept_idx, keep_num, pre_nms_num, 4, &proposals_filter);
  gather_rows(top_scores, kept_idx, keep_num, pre_nms_num, 1, &scores_filter);

  if (nms_thresh <= 0) {
    // NMS is skipped; wait for the queued work before returning the tensors.
    if (xctx->xpu_stream) {
      dev_ctx.Wait();
    }
    return std::make_pair(proposals_filter, scores_filter);
  }

  // 4. nms: kept_idx is reused to receive the indices selected by NMS.
  int nms_keep_num = 0;
  r = xpu::sorted_nms<T>(xctx,
                         proposals_filter.data<T>(),
                         kept_idx.data<int>(),
                         nms_keep_num,
                         keep_num,
                         nms_thresh,
                         pixel_offset);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "sorted_nms");

  const bool truncate = post_nms_top_n > 0 && post_nms_top_n < nms_keep_num;
  const int final_num = truncate ? post_nms_top_n : nms_keep_num;
  kept_idx.Resize({final_num});

  DenseTensor scores_nms, proposals_nms;
  alloc_rows(&proposals_nms, kept_idx.numel(), 4);
  alloc_rows(&scores_nms, kept_idx.numel(), 1);
  gather_rows(proposals_filter,
              kept_idx,
              kept_idx.numel(),
              keep_num,
              4,
              &proposals_nms);
  gather_rows(
      scores_filter, kept_idx, kept_idx.numel(), keep_num, 1, &scores_nms);

  if (xctx->xpu_stream) {
    dev_ctx.Wait();
  }
  return std::make_pair(proposals_nms, scores_nms);
}

template <typename T, typename Context>
275 276 277 278 279 280 281 282 283 284 285 286 287 288 289
void GenerateProposalsKernel(const Context& dev_ctx,
                             const DenseTensor& scores,
                             const DenseTensor& bbox_deltas,
                             const DenseTensor& im_shape,
                             const DenseTensor& anchors,
                             const DenseTensor& variances,
                             int pre_nms_top_n,
                             int post_nms_top_n,
                             float nms_thresh,
                             float min_size,
                             float eta,
                             bool pixel_offset,
                             DenseTensor* rpn_rois,
                             DenseTensor* rpn_roi_probs,
                             DenseTensor* rpn_rois_num) {
290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374
  PADDLE_ENFORCE_GE(eta,
                    1.,
                    phi::errors::InvalidArgument(
                        "Not support adaptive NMS. The attribute 'eta' "
                        "should not less than 1. But received eta=[%d]",
                        eta));

  auto& scores_dim = scores.dims();
  // the shape of bbox score
  int num = scores_dim[0];
  int c_score = scores_dim[1];
  int h_score = scores_dim[2];
  int w_score = scores_dim[3];

  auto& bbox_dim = bbox_deltas.dims();
  int c_bbox = bbox_dim[1];
  int h_bbox = bbox_dim[2];
  int w_bbox = bbox_dim[3];

  DenseTensor bbox_deltas_swap, scores_swap;
  bbox_deltas_swap.Resize(phi::make_ddim({num, h_bbox, w_bbox, c_bbox}));
  dev_ctx.template Alloc<T>(&bbox_deltas_swap);

  scores_swap.Resize(phi::make_ddim({num, h_score, w_score, c_score}));
  dev_ctx.template Alloc<T>(&scores_swap);

  std::vector<int> axis = {0, 2, 3, 1};
  int r = xpu::transpose<T>(dev_ctx.x_context(),
                            bbox_deltas.data<T>(),
                            bbox_deltas_swap.data<T>(),
                            {num, c_bbox, h_bbox, w_bbox},
                            axis);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");

  r = xpu::transpose<T>(dev_ctx.x_context(),
                        scores.data<T>(),
                        scores_swap.data<T>(),
                        {num, c_score, h_score, w_score},
                        axis);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");

  DenseTensor tmp_anchors = anchors;
  DenseTensor tmp_variances = variances;
  tmp_anchors.Resize(phi::make_ddim({tmp_anchors.numel() / 4, 4}));
  tmp_variances.Resize(phi::make_ddim({tmp_variances.numel() / 4, 4}));

  // output
  rpn_rois->Resize(phi::make_ddim({bbox_deltas.numel() / 4, 4}));
  dev_ctx.template Alloc<T>(rpn_rois);

  rpn_roi_probs->Resize(phi::make_ddim({scores.numel(), 1}));
  dev_ctx.template Alloc<T>(rpn_roi_probs);

  auto place = dev_ctx.GetPlace();
  auto cpu_place = phi::CPUPlace();

  int num_proposals = 0;
  std::vector<size_t> offset(1, 0);
  std::vector<int> tmp_num;

  for (int64_t i = 0; i < num; ++i) {
    DenseTensor im_shape_slice = im_shape.Slice(i, i + 1);
    DenseTensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
    DenseTensor scores_slice = scores_swap.Slice(i, i + 1);

    bbox_deltas_slice.Resize(phi::make_ddim({h_bbox * w_bbox * c_bbox / 4, 4}));
    scores_slice.Resize(phi::make_ddim({h_score * w_score * c_score, 1}));

    std::pair<DenseTensor, DenseTensor> tensor_pair =
        ProposalForOneImage<T>(dev_ctx,
                               im_shape_slice,
                               tmp_anchors,
                               tmp_variances,
                               bbox_deltas_slice,
                               scores_slice,
                               pre_nms_top_n,
                               post_nms_top_n,
                               nms_thresh,
                               min_size,
                               eta,
                               pixel_offset);

    DenseTensor& proposals = tensor_pair.first;
    DenseTensor& nscores = tensor_pair.second;

375 376 377 378 379 380 381 382 383 384
    r = xpu::copy(dev_ctx.x_context(),
                  proposals.data<T>(),
                  rpn_rois->data<T>() + num_proposals * 4,
                  proposals.numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
    r = xpu::copy(dev_ctx.x_context(),
                  nscores.data<T>(),
                  rpn_roi_probs->data<T>() + num_proposals,
                  nscores.numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410

    if (dev_ctx.x_context()->xpu_stream) {
      dev_ctx.Wait();
    }
    num_proposals += proposals.dims()[0];
    offset.emplace_back(num_proposals);
    tmp_num.push_back(proposals.dims()[0]);
  }

  if (rpn_rois_num != nullptr) {
    rpn_rois_num->Resize(phi::make_ddim({num}));
    dev_ctx.template Alloc<int>(rpn_rois_num);
    int* num_data = rpn_rois_num->data<int>();
    paddle::memory::Copy(
        place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num);
  }

  phi::LoD lod;
  lod.emplace_back(offset);
  rpn_rois->set_lod(lod);
  rpn_roi_probs->set_lod(lod);
  rpn_rois->Resize(phi::make_ddim({num_proposals, 4}));
  rpn_roi_probs->Resize(phi::make_ddim({num_proposals, 1}));
}
}  // namespace phi

411 412
PD_REGISTER_KERNEL(
    generate_proposals, XPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {}