multiclass_nms_op.cc 25.6 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2

3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6

7
http://www.apache.org/licenses/LICENSE-2.0
8

9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

J
jerrywgz 已提交
14
#include <glog/logging.h>
15

16
#include "paddle/fluid/framework/infershape_utils.h"
Y
Yi Wang 已提交
17
#include "paddle/fluid/framework/op_registry.h"
18
#include "paddle/phi/infermeta/ternary.h"
Z
zhiboniu 已提交
19
#include "paddle/phi/kernels/funcs/detection/nms_util.h"
20 21 22 23

namespace paddle {
namespace operators {

24
using Tensor = phi::DenseTensor;
25
using LoDTensor = phi::DenseTensor;
26

27 28
// Turns per-image RoI counts (an int tensor with one entry per image) into
// cumulative offsets: the result has numel + 1 entries, starts with 0, and
// entry i + 1 is the total number of RoIs in images 0..i.
inline std::vector<size_t> GetNmsLodFromRoisNum(
    const phi::DenseTensor* rois_num) {
  const int* counts = rois_num->data<int>();
  const int64_t num_images = rois_num->numel();

  std::vector<size_t> offsets(1, static_cast<size_t>(0));
  offsets.reserve(static_cast<size_t>(num_images) + 1);
  size_t running_total = 0;
  for (int64_t img = 0; img < num_images; ++img) {
    running_total += static_cast<size_t>(counts[img]);
    offsets.push_back(running_total);
  }
  return offsets;
}

D
dangqingqing 已提交
38
class MultiClassNMSOp : public framework::OperatorWithKernel {
39 40 41 42
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
X
xiaoting 已提交
43 44 45
    OP_INOUT_CHECK(ctx->HasInput("BBoxes"), "Input", "BBoxes", "MultiClassNMS");
    OP_INOUT_CHECK(ctx->HasInput("Scores"), "Input", "Scores", "MultiClassNMS");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MultiClassNMS");
D
dangqingqing 已提交
46
    auto box_dims = ctx->GetInputDim("BBoxes");
47
    auto score_dims = ctx->GetInputDim("Scores");
J
jerrywgz 已提交
48
    auto score_size = score_dims.size();
49

50
    if (ctx->IsRuntime()) {
51 52
      PADDLE_ENFORCE_EQ(score_size == 2 || score_size == 3,
                        true,
53 54 55 56
                        platform::errors::InvalidArgument(
                            "The rank of Input(Scores) must be 2 or 3"
                            ". But received rank = %d",
                            score_size));
57 58
      PADDLE_ENFORCE_EQ(box_dims.size(),
                        3,
X
xiaoting 已提交
59 60
                        platform::errors::InvalidArgument(
                            "The rank of Input(BBoxes) must be 3"
61
                            ". But received rank = %d",
X
xiaoting 已提交
62
                            box_dims.size()));
J
jerrywgz 已提交
63
      if (score_size == 3) {
64 65 66 67 68 69 70 71 72 73 74 75 76
        PADDLE_ENFORCE_EQ(box_dims[2] == 4 || box_dims[2] == 8 ||
                              box_dims[2] == 16 || box_dims[2] == 24 ||
                              box_dims[2] == 32,
                          true,
                          platform::errors::InvalidArgument(
                              "The last dimension of Input"
                              "(BBoxes) must be 4 or 8, "
                              "represents the layout of coordinate "
                              "[xmin, ymin, xmax, ymax] or "
                              "4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or "
                              "8 points: [xi, yi] i= 1,2,...,8 or "
                              "12 points: [xi, yi] i= 1,2,...,12 or "
                              "16 points: [xi, yi] i= 1,2,...,16"));
J
jerrywgz 已提交
77
        PADDLE_ENFORCE_EQ(
78 79
            box_dims[1],
            score_dims[2],
X
xiaoting 已提交
80 81 82 83 84
            platform::errors::InvalidArgument(
                "The 2nd dimension of Input(BBoxes) must be equal to "
                "last dimension of Input(Scores), which represents the "
                "predicted bboxes."
                "But received box_dims[1](%s) != socre_dims[2](%s)",
85 86
                box_dims[1],
                score_dims[2]));
J
jerrywgz 已提交
87
      } else {
88 89
        PADDLE_ENFORCE_EQ(box_dims[2],
                          4,
X
xiaoting 已提交
90
                          platform::errors::InvalidArgument(
91 92
                              "The last dimension of Input"
                              "(BBoxes) must be 4. But received dimension = %d",
X
xiaoting 已提交
93
                              box_dims[2]));
94
        PADDLE_ENFORCE_EQ(
95 96
            box_dims[1],
            score_dims[1],
97 98 99 100
            platform::errors::InvalidArgument(
                "The 2nd dimension of Input"
                "(BBoxes) must be equal to the 2nd dimension of Input(Scores). "
                "But received box dimension = %d, score dimension = %d",
101 102
                box_dims[1],
                score_dims[1]));
J
jerrywgz 已提交
103
      }
104
    }
105 106
    // Here the box_dims[0] is not the real dimension of output.
    // It will be rewritten in the computing kernel.
J
jerrywgz 已提交
107
    if (score_size == 3) {
108
      ctx->SetOutputDim("Out", {-1, box_dims[2] + 2});
J
jerrywgz 已提交
109 110 111
    } else {
      ctx->SetOutputDim("Out", {-1, box_dims[2] + 2});
    }
112 113 114
    if (!ctx->IsRuntime()) {
      ctx->SetLoDLevel("Out", std::max(ctx->GetLoDLevel("BBoxes"), 1));
    }
115
  }
D
dangqingqing 已提交
116 117 118 119 120

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
121
        OperatorWithKernel::IndicateVarDataType(ctx, "Scores"),
122
        platform::CPUPlace());
D
dangqingqing 已提交
123
  }
124 125
};

126 127
// Gathers the entries of class `class_id` from a per-item, per-class tensor.
//
// items: [N, C] (one value per item and class) or [N, C, K] (K values per item
//        and class).
// one_class_item: receives N values, or N rows of K values. Callers set its
//        dims (Resize) before calling; its buffer is obtained via mutable_data.
// Offsets are computed in int64_t so large tensors cannot overflow `int`.
template <class T>
void SliceOneClass(const platform::DeviceContext& ctx,
                   const phi::DenseTensor& items,
                   const int class_id,
                   phi::DenseTensor* one_class_item) {
  T* item_data = one_class_item->mutable_data<T>(ctx.GetPlace());
  const T* items_data = items.data<T>();
  const int64_t num_item = items.dims()[0];
  const int64_t class_num = items.dims()[1];
  if (items.dims().size() == 3) {
    const int64_t item_size = items.dims()[2];
    for (int64_t i = 0; i < num_item; ++i) {
      std::memcpy(item_data + i * item_size,
                  items_data + (i * class_num + class_id) * item_size,
                  sizeof(T) * item_size);
    }
  } else {
    for (int64_t i = 0; i < num_item; ++i) {
      item_data[i] = items_data[i * class_num + class_id];
    }
  }
}

149
template <typename T>
D
dangqingqing 已提交
150
class MultiClassNMSKernel : public framework::OpKernel<T> {
151
 public:
152 153
  void NMSFast(const phi::DenseTensor& bbox,
               const phi::DenseTensor& scores,
154 155 156 157 158
               const T score_threshold,
               const T nms_threshold,
               const T eta,
               const int64_t top_k,
               std::vector<int>* selected_indices,
J
jerrywgz 已提交
159
               const bool normalized) const {
160 161 162
    // The total boxes for each instance.
    int64_t num_boxes = bbox.dims()[0];
    // 4: [xmin ymin xmax ymax]
Y
Yipeng 已提交
163 164
    // 8: [x1 y1 x2 y2 x3 y3 x4 y4]
    // 16, 24, or 32: [x1 y1 x2 y2 ...  xn yn], n = 8, 12 or 16
165 166 167 168 169
    int64_t box_size = bbox.dims()[1];

    std::vector<T> scores_data(num_boxes);
    std::copy_n(scores.data<T>(), num_boxes, scores_data.begin());
    std::vector<std::pair<T, int>> sorted_indices;
Z
zhiboniu 已提交
170 171
    phi::funcs::GetMaxScoreIndex(
        scores_data, score_threshold, top_k, &sorted_indices);
172 173 174 175 176 177 178 179

    selected_indices->clear();
    T adaptive_threshold = nms_threshold;
    const T* bbox_data = bbox.data<T>();

    while (sorted_indices.size() != 0) {
      const int idx = sorted_indices.front().second;
      bool keep = true;
180
      for (size_t k = 0; k < selected_indices->size(); ++k) {
181 182
        if (keep) {
          const int kept_idx = (*selected_indices)[k];
Y
Yipeng 已提交
183 184 185
          T overlap = T(0.);
          // 4: [xmin ymin xmax ymax]
          if (box_size == 4) {
Z
zhiboniu 已提交
186 187 188 189
            overlap =
                phi::funcs::JaccardOverlap<T>(bbox_data + idx * box_size,
                                              bbox_data + kept_idx * box_size,
                                              normalized);
Y
Yipeng 已提交
190 191 192 193
          }
          // 8: [x1 y1 x2 y2 x3 y3 x4 y4] or 16, 24, 32
          if (box_size == 8 || box_size == 16 || box_size == 24 ||
              box_size == 32) {
Z
zhiboniu 已提交
194 195 196 197
            overlap = phi::funcs::PolyIoU<T>(bbox_data + idx * box_size,
                                             bbox_data + kept_idx * box_size,
                                             box_size,
                                             normalized);
Y
Yipeng 已提交
198
          }
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
          keep = overlap <= adaptive_threshold;
        } else {
          break;
        }
      }
      if (keep) {
        selected_indices->push_back(idx);
      }
      sorted_indices.erase(sorted_indices.begin());
      if (keep && eta < 1 && adaptive_threshold > 0.5) {
        adaptive_threshold *= eta;
      }
    }
  }

D
dangqingqing 已提交
214
  void MultiClassNMS(const framework::ExecutionContext& ctx,
215 216
                     const phi::DenseTensor& scores,
                     const phi::DenseTensor& bboxes,
J
jerrywgz 已提交
217
                     const int scores_size,
218 219
                     std::map<int, std::vector<int>>* indices,
                     int* num_nmsed_out) const {
D
dangqingqing 已提交
220 221 222
    int64_t background_label = ctx.Attr<int>("background_label");
    int64_t nms_top_k = ctx.Attr<int>("nms_top_k");
    int64_t keep_top_k = ctx.Attr<int>("keep_top_k");
J
jerrywgz 已提交
223
    bool normalized = ctx.Attr<bool>("normalized");
224 225
    T nms_threshold = static_cast<T>(ctx.Attr<float>("nms_threshold"));
    T nms_eta = static_cast<T>(ctx.Attr<float>("nms_eta"));
D
dangqingqing 已提交
226
    T score_threshold = static_cast<T>(ctx.Attr<float>("score_threshold"));
L
Leo Chen 已提交
227
    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
228 229

    int num_det = 0;
230 231 232 233 234 235 236 237 238 239 240 241 242

    int64_t class_num = scores_size == 3 ? scores.dims()[0] : scores.dims()[1];
    Tensor bbox_slice, score_slice;
    for (int64_t c = 0; c < class_num; ++c) {
      if (c == background_label) continue;
      if (scores_size == 3) {
        score_slice = scores.Slice(c, c + 1);
        bbox_slice = bboxes;
      } else {
        score_slice.Resize({scores.dims()[0], 1});
        bbox_slice.Resize({scores.dims()[0], 4});
        SliceOneClass<T>(dev_ctx, scores, c, &score_slice);
        SliceOneClass<T>(dev_ctx, bboxes, c, &bbox_slice);
J
jerrywgz 已提交
243
      }
244 245 246 247 248 249 250 251
      NMSFast(bbox_slice,
              score_slice,
              score_threshold,
              nms_threshold,
              nms_eta,
              nms_top_k,
              &((*indices)[c]),
              normalized);
252
      if (scores_size == 2) {
J
jerrywgz 已提交
253 254
        std::stable_sort((*indices)[c].begin(), (*indices)[c].end());
      }
255
      num_det += (*indices)[c].size();
256 257
    }

258
    *num_nmsed_out = num_det;
259 260
    const T* scores_data = scores.data<T>();
    if (keep_top_k > -1 && num_det > keep_top_k) {
J
jerrywgz 已提交
261
      const T* sdata;
262
      std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
263
      for (const auto& it : *indices) {
264
        int label = it.first;
J
jerrywgz 已提交
265
        if (scores_size == 3) {
266
          sdata = scores_data + label * scores.dims()[1];
J
jerrywgz 已提交
267
        } else {
268 269 270
          score_slice.Resize({scores.dims()[0], 1});
          SliceOneClass<T>(dev_ctx, scores, label, &score_slice);
          sdata = score_slice.data<T>();
J
jerrywgz 已提交
271
        }
272
        const std::vector<int>& label_indices = it.second;
273
        for (size_t j = 0; j < label_indices.size(); ++j) {
274 275 276 277 278 279
          int idx = label_indices[j];
          score_index_pairs.push_back(
              std::make_pair(sdata[idx], std::make_pair(label, idx)));
        }
      }
      // Keep top k results per image.
280 281
      std::stable_sort(score_index_pairs.begin(),
                       score_index_pairs.end(),
Z
zhiboniu 已提交
282
                       phi::funcs::SortScorePairDescend<std::pair<int, int>>);
283 284 285 286
      score_index_pairs.resize(keep_top_k);

      // Store the new indices.
      std::map<int, std::vector<int>> new_indices;
287
      for (size_t j = 0; j < score_index_pairs.size(); ++j) {
288 289 290 291
        int label = score_index_pairs[j].second.first;
        int idx = score_index_pairs[j].second.second;
        new_indices[label].push_back(idx);
      }
J
jerrywgz 已提交
292 293 294 295 296 297 298
      if (scores_size == 2) {
        for (const auto& it : new_indices) {
          int label = it.first;
          std::stable_sort(new_indices[label].begin(),
                           new_indices[label].end());
        }
      }
299 300
      new_indices.swap(*indices);
      *num_nmsed_out = keep_top_k;
301 302 303
    }
  }

J
jerrywgz 已提交
304
  void MultiClassOutput(const platform::DeviceContext& ctx,
305 306
                        const phi::DenseTensor& scores,
                        const phi::DenseTensor& bboxes,
307
                        const std::map<int, std::vector<int>>& selected_indices,
308
                        const int scores_size,
309
                        phi::DenseTensor* outs,
310 311
                        int* oindices = nullptr,
                        const int offset = 0) const {
J
jerrywgz 已提交
312
    int64_t class_num = scores.dims()[1];
Y
Yipeng 已提交
313 314
    int64_t predict_dim = scores.dims()[1];
    int64_t box_size = bboxes.dims()[1];
J
jerrywgz 已提交
315 316 317 318
    if (scores_size == 2) {
      box_size = bboxes.dims()[2];
    }
    int64_t out_dim = box_size + 2;
319 320 321
    auto* scores_data = scores.data<T>();
    auto* bboxes_data = bboxes.data<T>();
    auto* odata = outs->data<T>();
J
jerrywgz 已提交
322 323 324
    const T* sdata;
    Tensor bbox;
    bbox.Resize({scores.dims()[0], box_size});
325 326 327
    int count = 0;
    for (const auto& it : selected_indices) {
      int label = it.first;
D
dangqingqing 已提交
328
      const std::vector<int>& indices = it.second;
J
jerrywgz 已提交
329 330 331 332 333
      if (scores_size == 2) {
        SliceOneClass<T>(ctx, bboxes, label, &bbox);
      } else {
        sdata = scores_data + label * predict_dim;
      }
334

335
      for (size_t j = 0; j < indices.size(); ++j) {
336
        int idx = indices[j];
J
jerrywgz 已提交
337 338 339 340 341
        odata[count * out_dim] = label;  // label
        const T* bdata;
        if (scores_size == 3) {
          bdata = bboxes_data + idx * box_size;
          odata[count * out_dim + 1] = sdata[idx];  // score
342 343 344
          if (oindices != nullptr) {
            oindices[count] = offset + idx;
          }
J
jerrywgz 已提交
345 346 347
        } else {
          bdata = bbox.data<T>() + idx * box_size;
          odata[count * out_dim + 1] = *(scores_data + idx * class_num + label);
348 349 350
          if (oindices != nullptr) {
            oindices[count] = offset + idx * class_num + label;
          }
J
jerrywgz 已提交
351
        }
Y
Yipeng 已提交
352 353
        // xmin, ymin, xmax, ymax or multi-points coordinates
        std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T));
D
dangqingqing 已提交
354
        count++;
355 356 357 358 359
      }
    }
  }

  void Compute(const framework::ExecutionContext& ctx) const override {
J
jerrywgz 已提交
360 361
    auto* boxes = ctx.Input<LoDTensor>("BBoxes");
    auto* scores = ctx.Input<LoDTensor>("Scores");
362
    auto* outs = ctx.Output<LoDTensor>("Out");
363 364
    bool return_index = ctx.HasOutput("Index") ? true : false;
    auto index = ctx.Output<LoDTensor>("Index");
365
    bool has_roisnum = ctx.HasInput("RoisNum") ? true : false;
366
    auto rois_num = ctx.Input<phi::DenseTensor>("RoisNum");
367
    auto score_dims = scores->dims();
368
    auto score_size = score_dims.size();
L
Leo Chen 已提交
369
    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
370 371 372

    std::vector<std::map<int, std::vector<int>>> all_indices;
    std::vector<size_t> batch_starts = {0};
J
jerrywgz 已提交
373 374 375 376
    int64_t batch_size = score_dims[0];
    int64_t box_dim = boxes->dims()[2];
    int64_t out_dim = box_dim + 2;
    int num_nmsed_out = 0;
377
    Tensor boxes_slice, scores_slice;
378 379 380 381 382 383
    int n = 0;
    if (has_roisnum) {
      n = score_size == 3 ? batch_size : rois_num->numel();
    } else {
      n = score_size == 3 ? batch_size : boxes->lod().back().size() - 1;
    }
384
    for (int i = 0; i < n; ++i) {
385
      std::map<int, std::vector<int>> indices;
386 387 388 389 390 391
      if (score_size == 3) {
        scores_slice = scores->Slice(i, i + 1);
        scores_slice.Resize({score_dims[1], score_dims[2]});
        boxes_slice = boxes->Slice(i, i + 1);
        boxes_slice.Resize({score_dims[2], box_dim});
      } else {
392 393 394 395 396 397
        std::vector<size_t> boxes_lod;
        if (has_roisnum) {
          boxes_lod = GetNmsLodFromRoisNum(rois_num);
        } else {
          boxes_lod = boxes->lod().back();
        }
398 399 400 401 402
        if (boxes_lod[i] == boxes_lod[i + 1]) {
          all_indices.push_back(indices);
          batch_starts.push_back(batch_starts.back());
          continue;
        }
403 404
        scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]);
        boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]);
J
jerrywgz 已提交
405
      }
406 407
      MultiClassNMS(
          ctx, scores_slice, boxes_slice, score_size, &indices, &num_nmsed_out);
408 409
      all_indices.push_back(indices);
      batch_starts.push_back(batch_starts.back() + num_nmsed_out);
J
jerrywgz 已提交
410 411 412 413
    }

    int num_kept = batch_starts.back();
    if (num_kept == 0) {
414 415 416 417 418 419 420 421
      if (return_index) {
        outs->mutable_data<T>({0, out_dim}, ctx.GetPlace());
        index->mutable_data<int>({0, 1}, ctx.GetPlace());
      } else {
        T* od = outs->mutable_data<T>({1, 1}, ctx.GetPlace());
        od[0] = -1;
        batch_starts = {0, 1};
      }
J
jerrywgz 已提交
422 423
    } else {
      outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace());
424 425
      int offset = 0;
      int* oindices = nullptr;
426 427 428 429 430 431
      for (int i = 0; i < n; ++i) {
        if (score_size == 3) {
          scores_slice = scores->Slice(i, i + 1);
          boxes_slice = boxes->Slice(i, i + 1);
          scores_slice.Resize({score_dims[1], score_dims[2]});
          boxes_slice.Resize({score_dims[2], box_dim});
432 433 434
          if (return_index) {
            offset = i * score_dims[2];
          }
435
        } else {
436 437 438 439 440 441
          std::vector<size_t> boxes_lod;
          if (has_roisnum) {
            boxes_lod = GetNmsLodFromRoisNum(rois_num);
          } else {
            boxes_lod = boxes->lod().back();
          }
442
          if (boxes_lod[i] == boxes_lod[i + 1]) continue;
443 444
          scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]);
          boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]);
445 446 447
          if (return_index) {
            offset = boxes_lod[i] * score_dims[1];
          }
J
jerrywgz 已提交
448
        }
449

450 451 452 453
        int64_t s = batch_starts[i];
        int64_t e = batch_starts[i + 1];
        if (e > s) {
          Tensor out = outs->Slice(s, e);
454 455 456 457 458
          if (return_index) {
            int* output_idx =
                index->mutable_data<int>({num_kept, 1}, ctx.GetPlace());
            oindices = output_idx + s;
          }
459 460 461 462 463 464 465 466
          MultiClassOutput(dev_ctx,
                           scores_slice,
                           boxes_slice,
                           all_indices[i],
                           score_dims.size(),
                           &out,
                           oindices,
                           offset);
467 468 469
        }
      }
    }
470
    if (ctx.HasOutput("NmsRoisNum")) {
471
      auto* nms_rois_num = ctx.Output<phi::DenseTensor>("NmsRoisNum");
472 473 474 475 476 477 478
      nms_rois_num->mutable_data<int>({n}, ctx.GetPlace());
      int* num_data = nms_rois_num->data<int>();
      for (int i = 1; i <= n; i++) {
        num_data[i - 1] = batch_starts[i] - batch_starts[i - 1];
      }
      nms_rois_num->Resize({n});
    }
479 480 481

    framework::LoD lod;
    lod.emplace_back(batch_starts);
482 483 484
    if (return_index) {
      index->set_lod(lod);
    }
485 486 487 488
    outs->set_lod(lod);
  }
};

D
dangqingqing 已提交
489
class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
490
 public:
Y
Yu Yang 已提交
491
  void Make() override {
D
dangqingqing 已提交
492
    AddInput("BBoxes",
J
jerrywgz 已提交
493 494
             "Two types of bboxes are supported:"
             "1. (Tensor) A 3-D Tensor with shape "
Y
Yipeng 已提交
495
             "[N, M, 4 or 8 16 24 32] represents the "
496 497
             "predicted locations of M bounding bboxes, N is the batch size. "
             "Each bounding box has four coordinate values and the layout is "
J
jerrywgz 已提交
498
             "[xmin, ymin, xmax, ymax], when box size equals to 4."
499 500
             "2. (LoDTensor) A 3-D Tensor with shape [M, C, 4]"
             "M is the number of bounding boxes, C is the class number");
D
dangqingqing 已提交
501
    AddInput("Scores",
J
jerrywgz 已提交
502 503
             "Two types of scores are supported:"
             "1. (Tensor) A 3-D Tensor with shape [N, C, M] represents the "
D
dangqingqing 已提交
504 505 506
             "predicted confidence predictions. N is the batch size, C is the "
             "class number, M is number of bounding boxes. For each category "
             "there are total M scores which corresponding M bounding boxes. "
507 508 509 510
             " Please note, M is equal to the 2nd dimension of BBoxes. "
             "2. (LoDTensor) A 2-D LoDTensor with shape [M, C]. "
             "M is the number of bbox, C is the class number. In this case, "
             "Input BBoxes should be the second case with shape [M, C, 4].");
D
dangqingqing 已提交
511
    AddAttr<int>(
512
        "background_label",
翟飞跃 已提交
513
        "(int, default: 0) "
D
dangqingqing 已提交
514 515
        "The index of background label, the background label will be ignored. "
        "If set to -1, then all categories will be considered.")
516
        .SetDefault(0);
D
dangqingqing 已提交
517 518
    AddAttr<float>("score_threshold",
                   "(float) "
D
dangqingqing 已提交
519 520
                   "Threshold to filter out bounding boxes with low "
                   "confidence score. If not provided, consider all boxes.");
D
dangqingqing 已提交
521 522 523
    AddAttr<int>("nms_top_k",
                 "(int64_t) "
                 "Maximum number of detections to be kept according to the "
T
tianshuo78520a 已提交
524
                 "confidences after the filtering detections based on "
D
dangqingqing 已提交
525
                 "score_threshold");
526
    AddAttr<float>("nms_threshold",
翟飞跃 已提交
527
                   "(float, default: 0.3) "
D
dangqingqing 已提交
528
                   "The threshold to be used in NMS.")
529 530 531
        .SetDefault(0.3);
    AddAttr<float>("nms_eta",
                   "(float) "
D
dangqingqing 已提交
532
                   "The parameter for adaptive NMS.")
533
        .SetDefault(1.0);
D
dangqingqing 已提交
534 535 536 537
    AddAttr<int>("keep_top_k",
                 "(int64_t) "
                 "Number of total bboxes to be kept per image after NMS "
                 "step. -1 means keeping all bboxes after NMS step.");
J
jerrywgz 已提交
538
    AddAttr<bool>("normalized",
J
jerrywgz 已提交
539
                  "(bool, default true) "
J
jerrywgz 已提交
540 541
                  "Whether detections are normalized.")
        .SetDefault(true);
542 543 544
    AddOutput("Out",
              "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the "
              "detections. Each row has 6 values: "
Y
Yipeng 已提交
545 546 547 548 549 550
              "[label, confidence, xmin, ymin, xmax, ymax] or "
              "(LoDTensor) A 2-D LoDTensor with shape [No, 10] represents the "
              "detections. Each row has 10 values: "
              "[label, confidence, x1, y1, x2, y2, x3, y3, x4, y4]. No is the "
              "total number of detections in this mini-batch."
              "For each instance, "
551 552 553 554
              "the offsets in first dimension are called LoD, the number of "
              "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
              "no detected bbox.");
    AddComment(R"DOC(
D
dangqingqing 已提交
555
This operator is to do multi-class non maximum suppression (NMS) on a batched
556
of boxes and scores.
D
dangqingqing 已提交
557 558 559 560 561 562
In the NMS step, this operator greedily selects a subset of detection bounding
boxes that have high scores larger than score_threshold, if providing this
threshold, then selects the largest nms_top_k confidences scores if nms_top_k
is larger than -1. Then this operator pruns away boxes that have high IOU
(intersection over union) overlap with already selected boxes by adaptive
threshold NMS based on parameters of nms_threshold and nms_eta.
563
Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
D
dangqingqing 已提交
564 565
per image if keep_top_k is larger than -1.
This operator support multi-class and batched inputs. It applying NMS
566 567 568
independently for each class. The outputs is a 2-D LoDTenosr, for each
image, the offsets in first dimension of LoDTensor are called LoD, the number
of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0,
569
means there is no detected bbox for this image.
570 571 572 573
)DOC");
  }
};

574 575 576 577 578 579 580 581 582 583 584 585 586 587
// multiclass_nms2: multiclass_nms plus an Index output holding, for each row
// of Out, the absolute index of the selected box across the batch.
class MultiClassNMS2Op : public MultiClassNMSOp {
 public:
  MultiClassNMS2Op(const std::string& type,
                   const framework::VariableNameMap& inputs,
                   const framework::VariableNameMap& outputs,
                   const framework::AttributeMap& attrs)
      : MultiClassNMSOp(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext* ctx) const override {
    MultiClassNMSOp::InferShape(ctx);

    // One index per output row for both Scores layouts; the row count is
    // only known once the kernel has run NMS.
    ctx->SetOutputDim("Index", {-1, 1});
    if (!ctx->IsRuntime()) {
      ctx->SetLoDLevel("Index", std::max(ctx->GetLoDLevel("BBoxes"), 1));
    }
  }
};

// Proto of multiclass_nms2: the multiclass_nms proto plus the Index output.
class MultiClassNMS2OpMaker : public MultiClassNMSOpMaker {
 public:
  void Make() override {
    MultiClassNMSOpMaker::Make();
    AddOutput("Index",
              "(LoDTensor) A 2-D LoDTensor with shape [No, 1] represents the "
              "index of selected bbox. The index is the absolute index across "
              "batches.")
        .AsIntermediate();
  }
};

610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631
class MultiClassNMS3Op : public MultiClassNMS2Op {
 public:
  MultiClassNMS3Op(const std::string& type,
                   const framework::VariableNameMap& inputs,
                   const framework::VariableNameMap& outputs,
                   const framework::AttributeMap& attrs)
      : MultiClassNMS2Op(type, inputs, outputs, attrs) {}
};

// Proto of multiclass_nms3: the multiclass_nms2 proto plus the optional
// per-image RoI counts as input and the per-image kept counts as output.
class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
 public:
  void Make() override {
    MultiClassNMS2OpMaker::Make();
    AddInput("RoisNum",
             "(Tensor) The number of RoIs in shape (B), "
             "B is the number of images")
        .AsDispensable();
    AddOutput("NmsRoisNum", "(Tensor) The number of NMS RoIs in each image")
        .AsDispensable();
  }
};

632 633 634
}  // namespace operators
}  // namespace paddle

635 636 637 638
DECLARE_INFER_SHAPE_FUNCTOR(multiclass_nms3,
                            MultiClassNMSShapeFunctor,
                            PD_INFER_META(phi::MultiClassNMSInferMeta));

639
namespace ops = paddle::operators;
H
hong 已提交
640
REGISTER_OPERATOR(
641 642 643
    multiclass_nms,
    ops::MultiClassNMSOp,
    ops::MultiClassNMSOpMaker,
H
hong 已提交
644 645
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
646 647
REGISTER_OP_CPU_KERNEL(multiclass_nms,
                       ops::MultiClassNMSKernel<float>,
D
dangqingqing 已提交
648
                       ops::MultiClassNMSKernel<double>);
H
hong 已提交
649
REGISTER_OPERATOR(
650 651 652
    multiclass_nms2,
    ops::MultiClassNMS2Op,
    ops::MultiClassNMS2OpMaker,
H
hong 已提交
653 654
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
655 656
REGISTER_OP_CPU_KERNEL(multiclass_nms2,
                       ops::MultiClassNMSKernel<float>,
657
                       ops::MultiClassNMSKernel<double>);
658 659

REGISTER_OPERATOR(
660 661 662
    multiclass_nms3,
    ops::MultiClassNMS3Op,
    ops::MultiClassNMS3OpMaker,
663
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
664 665
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    MultiClassNMSShapeFunctor);