/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cstring>
#include <map>
#include <utility>
#include <vector>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

// Number of values in one row of Output(Out):
// [label, confidence, xmin, ymin, xmax, ymax].
constexpr int64_t kOutputDim{6};
// Number of coordinates describing one box: [xmin, ymin, xmax, ymax].
constexpr int64_t kBBoxSize{4};

D
dangqingqing 已提交
27
class MultiClassNMSOp : public framework::OperatorWithKernel {
28 29 30 31
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
D
dangqingqing 已提交
32 33
    PADDLE_ENFORCE(ctx->HasInput("BBoxes"),
                   "Input(BBoxes) of MultiClassNMS should not be null.");
34
    PADDLE_ENFORCE(ctx->HasInput("Scores"),
D
dangqingqing 已提交
35 36 37
                   "Input(Scores) of MultiClassNMS should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of MultiClassNMS should not be null.");
38

D
dangqingqing 已提交
39
    auto box_dims = ctx->GetInputDim("BBoxes");
40 41
    auto score_dims = ctx->GetInputDim("Scores");

D
dangqingqing 已提交
42
    PADDLE_ENFORCE_EQ(box_dims.size(), 2,
D
dangqingqing 已提交
43
                      "The rank of Input(BBoxes) must be 2.");
44 45
    PADDLE_ENFORCE_EQ(score_dims.size(), 3,
                      "The rank of Input(Scores) must be 3.");
D
dangqingqing 已提交
46 47 48 49 50 51 52 53
    PADDLE_ENFORCE_EQ(box_dims[1], 4,
                      "The 2nd dimension of Input(BBoxes) must be 4, "
                      "represents the layout of coordinate "
                      "[xmin, ymin, xmax, ymax]");
    PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2],
                      "The 1st dimensiong of Input(BBoxes) must be equal to "
                      "3rd dimension of Input(Scores), which represents the "
                      "predicted bboxes.");
54 55 56 57 58

    // Here the box_dims[0] is not the real dimension of output.
    // It will be rewritten in the computing kernel.
    ctx->SetOutputDim("Out", {box_dims[0], 6});
  }
D
dangqingqing 已提交
59 60 61 62 63 64 65 66 67

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(
            ctx.Input<framework::LoDTensor>("Scores")->type()),
        ctx.device_context());
  }
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
};

// Comparator ordering (score, payload) pairs by score, highest first; usable
// with std::sort / std::stable_sort.
//
// T is the payload type and S the score type. S defaults to float, so
// existing uses such as `SortScorePairDescend<int>` keep their meaning. Callers
// holding double scores can use `SortScorePairDescend<int, double>` instead of
// having every pair converted to pair<float, T>. That conversion loses
// precision and builds a temporary pair on each comparison.
template <class T, class S = float>
bool SortScorePairDescend(const std::pair<S, T>& pair1,
                          const std::pair<S, T>& pair2) {
  return pair1.first > pair2.first;
}

// Appends a (score, index) pair to `sorted_indices` for every entry of
// `scores` strictly greater than `threshold`. It then stable-sorts the whole
// vector by score in descending order; equal scores keep their insertion
// order. If `top_k` is non-negative, only the first `top_k` pairs are kept.
// Callers are expected to pass an empty vector; any pre-existing entries are
// sorted together with the new ones.
//
// The comparison is done in T itself. Comparing through pair<float, int>
// would round double scores to float, merging distinct scores and
// mis-ranking them.
template <class T>
static inline void GetMaxScoreIndex(
    const std::vector<T>& scores, const T threshold, int top_k,
    std::vector<std::pair<T, int>>* sorted_indices) {
  for (size_t i = 0; i < scores.size(); ++i) {
    if (scores[i] > threshold) {
      sorted_indices->emplace_back(scores[i], static_cast<int>(i));
    }
  }
  // Sort the score pairs according to the scores in descending order.
  std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
                   [](const std::pair<T, int>& lhs,
                      const std::pair<T, int>& rhs) {
                     return lhs.first > rhs.first;
                   });
  // Keep top_k scores if needed (top_k is non-negative in this branch).
  if (top_k > -1 && static_cast<size_t>(top_k) < sorted_indices->size()) {
    sorted_indices->resize(top_k);
  }
}

// Returns the area of `box`, laid out as [xmin, ymin, xmax, ymax].
// An inverted box (xmax < xmin or ymax < ymin) has area 0.
// When `normalized` is true the area is w * h; otherwise the coordinates are
// not within range [0, 1] and 1 is added to each side: (w + 1) * (h + 1).
template <class T>
T BBoxArea(const T* box, const bool normalized) {
  if (box[2] < box[0] || box[3] < box[1]) {
    // If coordinate values are invalid
    // (e.g. xmax < xmin or ymax < ymin), return 0.
    return static_cast<T>(0.);
  }
  const T w = box[2] - box[0];
  const T h = box[3] - box[1];
  if (normalized) {
    return w * h;
  }
  // Coordinate values are not within range [0, 1].
  return (w + 1) * (h + 1);
}

// Returns the intersection-over-union (IoU) of two boxes laid out as
// [xmin, ymin, xmax, ymax]. Boxes that do not touch give 0.
// `normalized` selects the same area convention as BBoxArea. When it is
// false, 1 is added to each side of the intersection as well as of the two
// box areas, so numerator and denominator are measured consistently.
template <class T>
static inline T JaccardOverlap(const T* box1, const T* box2,
                               const bool normalized) {
  if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
      box2[3] < box1[1]) {
    return static_cast<T>(0.);
  }
  const T inter_xmin = std::max(box1[0], box2[0]);
  const T inter_ymin = std::max(box1[1], box2[1]);
  const T inter_xmax = std::min(box1[2], box2[2]);
  const T inter_ymax = std::min(box1[3], box2[3]);
  // BBoxArea uses (w + 1) * (h + 1) for un-normalized boxes. Without the same
  // offset here, the intersection would be under-counted. For example, two
  // identical boxes would get an IoU below 1.
  const T offset = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
  const T inter_w = inter_xmax - inter_xmin + offset;
  const T inter_h = inter_ymax - inter_ymin + offset;
  const T inter_area = inter_w * inter_h;
  const T bbox1_area = BBoxArea<T>(box1, normalized);
  const T bbox2_area = BBoxArea<T>(box2, normalized);
  return inter_area / (bbox1_area + bbox2_area - inter_area);
}

template <typename T>
D
dangqingqing 已提交
133
class MultiClassNMSKernel : public framework::OpKernel<T> {
134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
 public:
  void NMSFast(const Tensor& bbox, const Tensor& scores,
               const T score_threshold, const T nms_threshold, const T eta,
               const int64_t top_k, std::vector<int>* selected_indices) const {
    // The total boxes for each instance.
    int64_t num_boxes = bbox.dims()[0];
    // 4: [xmin ymin xmax ymax]
    int64_t box_size = bbox.dims()[1];

    std::vector<T> scores_data(num_boxes);
    std::copy_n(scores.data<T>(), num_boxes, scores_data.begin());
    std::vector<std::pair<T, int>> sorted_indices;
    GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices);

    selected_indices->clear();
    T adaptive_threshold = nms_threshold;
    const T* bbox_data = bbox.data<T>();

    while (sorted_indices.size() != 0) {
      const int idx = sorted_indices.front().second;
      bool keep = true;
      for (int k = 0; k < selected_indices->size(); ++k) {
        if (keep) {
          const int kept_idx = (*selected_indices)[k];
          T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
                                        bbox_data + kept_idx * box_size, true);
          keep = overlap <= adaptive_threshold;
        } else {
          break;
        }
      }
      if (keep) {
        selected_indices->push_back(idx);
      }
      sorted_indices.erase(sorted_indices.begin());
      if (keep && eta < 1 && adaptive_threshold > 0.5) {
        adaptive_threshold *= eta;
      }
    }
  }

D
dangqingqing 已提交
175
  void MultiClassNMS(const framework::ExecutionContext& ctx,
176
                     const Tensor& scores, const Tensor& bboxes,
D
dangqingqing 已提交
177 178
                     std::map<int, std::vector<int>>& indices,
                     int& num_nmsed_out) const {
D
dangqingqing 已提交
179 180 181
    int64_t background_label = ctx.Attr<int>("background_label");
    int64_t nms_top_k = ctx.Attr<int>("nms_top_k");
    int64_t keep_top_k = ctx.Attr<int>("keep_top_k");
182 183
    T nms_threshold = static_cast<T>(ctx.Attr<float>("nms_threshold"));
    T nms_eta = static_cast<T>(ctx.Attr<float>("nms_eta"));
D
dangqingqing 已提交
184
    T score_threshold = static_cast<T>(ctx.Attr<float>("score_threshold"));
185 186 187 188 189 190 191 192

    int64_t class_num = scores.dims()[0];
    int64_t predict_dim = scores.dims()[1];
    int num_det = 0;
    for (int64_t c = 0; c < class_num; ++c) {
      if (c == background_label) continue;
      Tensor score = scores.Slice(c, c + 1);
      NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k,
D
dangqingqing 已提交
193 194
              &(indices[c]));
      num_det += indices[c].size();
195 196
    }

D
dangqingqing 已提交
197
    num_nmsed_out = num_det;
198 199 200
    const T* scores_data = scores.data<T>();
    if (keep_top_k > -1 && num_det > keep_top_k) {
      std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
D
dangqingqing 已提交
201
      for (const auto& it : indices) {
202 203 204 205 206 207 208 209 210 211 212
        int label = it.first;
        const T* sdata = scores_data + label * predict_dim;
        const std::vector<int>& label_indices = it.second;
        for (int j = 0; j < label_indices.size(); ++j) {
          int idx = label_indices[j];
          PADDLE_ENFORCE_LT(idx, predict_dim);
          score_index_pairs.push_back(
              std::make_pair(sdata[idx], std::make_pair(label, idx)));
        }
      }
      // Keep top k results per image.
213 214
      std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
                       SortScorePairDescend<std::pair<int, int>>);
215 216 217 218 219 220 221 222 223
      score_index_pairs.resize(keep_top_k);

      // Store the new indices.
      std::map<int, std::vector<int>> new_indices;
      for (int j = 0; j < score_index_pairs.size(); ++j) {
        int label = score_index_pairs[j].second.first;
        int idx = score_index_pairs[j].second.second;
        new_indices[label].push_back(idx);
      }
D
dangqingqing 已提交
224 225
      new_indices.swap(indices);
      num_nmsed_out = keep_top_k;
226 227 228
    }
  }

D
dangqingqing 已提交
229
  void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
230 231 232 233 234 235 236 237 238 239 240
                        std::map<int, std::vector<int>>& selected_indices,
                        Tensor* outs) const {
    int predict_dim = scores.dims()[1];
    auto* scores_data = scores.data<T>();
    auto* bboxes_data = bboxes.data<T>();
    auto* odata = outs->data<T>();

    int count = 0;
    for (const auto& it : selected_indices) {
      int label = it.first;
      const T* sdata = scores_data + label * predict_dim;
D
dangqingqing 已提交
241
      const std::vector<int>& indices = it.second;
242 243 244 245 246
      for (int j = 0; j < indices.size(); ++j) {
        int idx = indices[j];
        const T* bdata = bboxes_data + idx * kBBoxSize;
        odata[count * kOutputDim] = label;           // label
        odata[count * kOutputDim + 1] = sdata[idx];  // score
D
dangqingqing 已提交
247 248
        // xmin, ymin, xmax, ymax
        std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T));
D
dangqingqing 已提交
249
        count++;
250 251 252 253 254
      }
    }
  }

  void Compute(const framework::ExecutionContext& ctx) const override {
D
dangqingqing 已提交
255
    auto* boxes = ctx.Input<Tensor>("BBoxes");
256 257 258 259 260
    auto* scores = ctx.Input<Tensor>("Scores");
    auto* outs = ctx.Output<LoDTensor>("Out");

    auto score_dims = scores->dims();

D
dangqingqing 已提交
261
    int64_t batch_size = score_dims[0];
262 263 264 265 266 267 268 269 270 271
    int64_t class_num = score_dims[1];
    int64_t predict_dim = score_dims[2];

    std::vector<std::map<int, std::vector<int>>> all_indices;
    std::vector<size_t> batch_starts = {0};
    for (int64_t i = 0; i < batch_size; ++i) {
      Tensor ins_score = scores->Slice(i, i + 1);
      ins_score.Resize({class_num, predict_dim});
      std::map<int, std::vector<int>> indices;
      int num_nmsed_out = 0;
D
dangqingqing 已提交
272
      MultiClassNMS(ctx, ins_score, *boxes, indices, num_nmsed_out);
273 274 275 276 277 278
      all_indices.push_back(indices);
      batch_starts.push_back(batch_starts.back() + num_nmsed_out);
    }

    int num_kept = batch_starts.back();
    if (num_kept == 0) {
279 280
      T* od = outs->mutable_data<T>({1}, ctx.GetPlace());
      od[0] = -1;
281 282 283 284 285 286 287 288 289
    } else {
      outs->mutable_data<T>({num_kept, kOutputDim}, ctx.GetPlace());
      for (int64_t i = 0; i < batch_size; ++i) {
        Tensor ins_score = scores->Slice(i, i + 1);
        ins_score.Resize({class_num, predict_dim});
        int64_t s = batch_starts[i];
        int64_t e = batch_starts[i + 1];
        if (e > s) {
          Tensor out = outs->Slice(s, e);
D
dangqingqing 已提交
290
          MultiClassOutput(ins_score, *boxes, all_indices[i], &out);
291 292 293 294 295 296 297 298 299 300 301
        }
      }
    }

    framework::LoD lod;
    lod.emplace_back(batch_starts);

    outs->set_lod(lod);
  }
};

D
dangqingqing 已提交
302
class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
303
 public:
D
dangqingqing 已提交
304
  MultiClassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker)
305
      : OpProtoAndCheckerMaker(proto, op_checker) {
D
dangqingqing 已提交
306 307 308 309 310
    AddInput("BBoxes",
             "(Tensor) A 2-D Tensor with shape [M, 4] represents the "
             "predicted locations of M bounding bboxes. Each bounding box "
             "has four coordinate values and the layout is "
             "[xmin, ymin, xmax, ymax].");
D
dangqingqing 已提交
311 312
    AddInput("Scores",
             "(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
D
dangqingqing 已提交
313 314 315 316
             "predicted confidence predictions. N is the batch size, C is the "
             "class number, M is number of bounding boxes. For each category "
             "there are total M scores which corresponding M bounding boxes. "
             " Please note, M is equal to the 1st dimension of BBoxes. ");
D
dangqingqing 已提交
317
    AddAttr<int>(
318 319
        "background_label",
        "(int64_t, defalut: 0) "
D
dangqingqing 已提交
320 321
        "The index of background label, the background label will be ignored. "
        "If set to -1, then all categories will be considered.")
322
        .SetDefault(0);
D
dangqingqing 已提交
323 324
    AddAttr<float>("score_threshold",
                   "(float) "
D
dangqingqing 已提交
325 326
                   "Threshold to filter out bounding boxes with low "
                   "confidence score. If not provided, consider all boxes.");
D
dangqingqing 已提交
327 328 329 330 331
    AddAttr<int>("nms_top_k",
                 "(int64_t) "
                 "Maximum number of detections to be kept according to the "
                 "confidences aftern the filtering detections based on "
                 "score_threshold");
332 333
    AddAttr<float>("nms_threshold",
                   "(float, defalut: 0.3) "
D
dangqingqing 已提交
334
                   "The threshold to be used in NMS.")
335 336 337
        .SetDefault(0.3);
    AddAttr<float>("nms_eta",
                   "(float) "
D
dangqingqing 已提交
338
                   "The parameter for adaptive NMS.")
339
        .SetDefault(1.0);
D
dangqingqing 已提交
340 341 342 343
    AddAttr<int>("keep_top_k",
                 "(int64_t) "
                 "Number of total bboxes to be kept per image after NMS "
                 "step. -1 means keeping all bboxes after NMS step.");
344 345 346 347 348 349 350 351 352
    AddOutput("Out",
              "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the "
              "detections. Each row has 6 values: "
              "[label, confidence, xmin, ymin, xmax, ymax], No is the total "
              "number of detections in this mini-batch. For each instance, "
              "the offsets in first dimension are called LoD, the number of "
              "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
              "no detected bbox.");
    AddComment(R"DOC(
D
dangqingqing 已提交
353
This operator is to do multi-class non maximum suppression (NMS) on a batched
354 355
of boxes and scores.

D
dangqingqing 已提交
356 357 358 359 360 361 362
In the NMS step, this operator greedily selects a subset of detection bounding
boxes that have high scores larger than score_threshold, if providing this
threshold, then selects the largest nms_top_k confidences scores if nms_top_k
is larger than -1. Then this operator pruns away boxes that have high IOU
(intersection over union) overlap with already selected boxes by adaptive
threshold NMS based on parameters of nms_threshold and nms_eta.

363
Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
D
dangqingqing 已提交
364
per image if keep_top_k is larger than -1.
365

D
dangqingqing 已提交
366
This operator support multi-class and batched inputs. It applying NMS
367 368 369 370 371 372
independently for each class. The outputs is a 2-D LoDTenosr, for each
image, the offsets in first dimension of LoDTensor are called LoD, the number
of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0,
means there is no detected bbox for this image. If there is no detected boxes
for all images, all the elements in LoD are 0, and the Out only contains one
value which is -1.
373 374 375 376 377 378 379 380
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
D
dangqingqing 已提交
381 382
REGISTER_OPERATOR(multiclass_nms, ops::MultiClassNMSOp,
                  ops::MultiClassNMSOpMaker,
383
                  paddle::framework::EmptyGradOpMaker);
D
dangqingqing 已提交
384 385
REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MultiClassNMSKernel<float>,
                       ops::MultiClassNMSKernel<double>);