From 912a4f2511ad118d7a989cbe4e7f634503670e34 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 29 Jan 2018 23:49:56 +0800 Subject: [PATCH] Add multi-class non-maximum suppression operator. --- paddle/operators/multiclass_nms_op.cc | 353 ++++++++++++++++++ .../v2/fluid/tests/test_bipartite_match_op.py | 2 +- .../v2/fluid/tests/test_multiclass_nms_op.py | 199 ++++++++++ 3 files changed, 553 insertions(+), 1 deletion(-) create mode 100644 paddle/operators/multiclass_nms_op.cc create mode 100644 python/paddle/v2/fluid/tests/test_multiclass_nms_op.py diff --git a/paddle/operators/multiclass_nms_op.cc b/paddle/operators/multiclass_nms_op.cc new file mode 100644 index 000000000..19c5b7efd --- /dev/null +++ b/paddle/operators/multiclass_nms_op.cc @@ -0,0 +1,353 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +constexpr int64_t kOutputDim = 6; +constexpr int64_t kBBoxSize = 4; + +class MulticlassNMSOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Bboxes"), + "Input(Bboxes) of MulticlassNMS should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Scores"), + "Input(Scores) of MulticlassNMS should not be null."); + + auto box_dims = ctx->GetInputDim("Bboxes"); + auto score_dims = ctx->GetInputDim("Scores"); + + PADDLE_ENFORCE_EQ(box_dims.size(), 3, + "The rank of Input(Bboxes) must be 3."); + PADDLE_ENFORCE_EQ(score_dims.size(), 3, + "The rank of Input(Scores) must be 3."); + PADDLE_ENFORCE_EQ(box_dims[0], score_dims[0]); + PADDLE_ENFORCE_EQ(box_dims[2], 4); + PADDLE_ENFORCE_EQ(box_dims[1], score_dims[2]); + + // Here the box_dims[0] is not the real dimension of output. + // It will be rewritten in the computing kernel. + ctx->SetOutputDim("Out", {box_dims[0], 6}); + } +}; + +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +template +static inline void GetMaxScoreIndex( + const std::vector& scores, const T threshold, int top_k, + std::vector>* sorted_indices) { + for (size_t i = 0; i < scores.size(); ++i) { + if (scores[i] > threshold) { + sorted_indices->push_back(std::make_pair(scores[i], i)); + } + } + // Sort the score pair according to the scores in descending order + std::stable_sort(sorted_indices->begin(), sorted_indices->end(), + SortScorePairDescend); + // Keep top_k scores if needed. + if (top_k > -1 && top_k < sorted_indices->size()) { + sorted_indices->resize(top_k); + } +} + +template +T BBoxArea(const T* box, const bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. + return T(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If bbox is not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +template +static inline T JaccardOverlap(const T* box1, const T* box2, + const bool normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return static_cast(0.); + } else { + const T inter_xmin = std::max(box1[0], box2[0]); + const T inter_ymin = std::max(box1[1], box2[1]); + const T inter_xmax = std::min(box1[2], box2[2]); + const T inter_ymax = std::min(box1[3], box2[3]); + const T inter_w = inter_xmax - inter_xmin; + const T inter_h = inter_ymax - inter_ymin; + const T inter_area = inter_w * inter_h; + const T bbox1_area = BBoxArea(box1, normalized); + const T bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +class MulticlassNMSKernel : public framework::OpKernel { + public: + void NMSFast(const Tensor& bbox, const Tensor& scores, + const T score_threshold, const T nms_threshold, const T eta, + const int64_t top_k, std::vector* selected_indices) const { + // The total boxes for each instance. + int64_t num_boxes = bbox.dims()[0]; + // 4: [xmin ymin xmax ymax] + int64_t box_size = bbox.dims()[1]; + + std::vector scores_data(num_boxes); + std::copy_n(scores.data(), num_boxes, scores_data.begin()); + std::vector> sorted_indices; + GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices); + + selected_indices->clear(); + T adaptive_threshold = nms_threshold; + const T* bbox_data = bbox.data(); + + while (sorted_indices.size() != 0) { + const int idx = sorted_indices.front().second; + bool keep = true; + for (int k = 0; k < selected_indices->size(); ++k) { + if (keep) { + const int kept_idx = (*selected_indices)[k]; + T overlap = JaccardOverlap(bbox_data + idx * box_size, + bbox_data + kept_idx * box_size, true); + keep = overlap <= adaptive_threshold; + } else { + break; + } + } + if (keep) { + selected_indices->push_back(idx); + } + sorted_indices.erase(sorted_indices.begin()); + if (keep && eta < 1 && adaptive_threshold > 0.5) { + adaptive_threshold *= eta; + } + } + } + + void MulticlassNMS(const framework::ExecutionContext& ctx, + const Tensor& scores, const Tensor& bboxes, + std::map>* indices, + int* num_nmsed_out) const { + int64_t background_label = ctx.Attr("background_label"); + int64_t nms_top_k = ctx.Attr("nms_top_k"); + int64_t keep_top_k = ctx.Attr("keep_top_k"); + T nms_threshold = static_cast(ctx.Attr("nms_threshold")); + T nms_eta = static_cast(ctx.Attr("nms_eta")); + T score_threshold = static_cast(ctx.Attr("confidence_threshold")); + + int64_t class_num = scores.dims()[0]; + int64_t predict_dim = scores.dims()[1]; + int num_det = 0; + for (int64_t c = 0; c < class_num; ++c) { + if (c == background_label) continue; + Tensor score = scores.Slice(c, c + 1); + NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k, + &((*indices)[c])); + num_det += indices[c].size(); + } + + *num_nmsed_out = num_det; + const T* scores_data = scores.data(); + if (keep_top_k > -1 && num_det > keep_top_k) { + std::vector>> score_index_pairs; + for (const auto& it : *indices) { + int label = it.first; + const T* sdata = scores_data + label * predict_dim; + const std::vector& label_indices = it.second; + for (int j = 0; j < label_indices.size(); ++j) { + int idx = label_indices[j]; + PADDLE_ENFORCE_LT(idx, predict_dim); + score_index_pairs.push_back( + std::make_pair(sdata[idx], std::make_pair(label, idx))); + } + } + // Keep top k results per image. + std::sort(score_index_pairs.begin(), score_index_pairs.end(), + SortScorePairDescend>); + score_index_pairs.resize(keep_top_k); + + // Store the new indices. + std::map> new_indices; + for (int j = 0; j < score_index_pairs.size(); ++j) { + int label = score_index_pairs[j].second.first; + int idx = score_index_pairs[j].second.second; + new_indices[label].push_back(idx); + } + new_indices.swap(*indices); + *num_nmsed_out = keep_top_k; + } + } + + void MulticlassOutput(const Tensor& scores, const Tensor& bboxes, + std::map>& selected_indices, + Tensor* outs) const { + int predict_dim = scores.dims()[1]; + auto* scores_data = scores.data(); + auto* bboxes_data = bboxes.data(); + auto* odata = outs->data(); + + int count = 0; + for (const auto& it : selected_indices) { + int label = it.first; + const T* sdata = scores_data + label * predict_dim; + std::vector indices = it.second; + for (int j = 0; j < indices.size(); ++j) { + int idx = indices[j]; + const T* bdata = bboxes_data + idx * kBBoxSize; + odata[count * kOutputDim] = label; // label + odata[count * kOutputDim + 1] = sdata[idx]; // score + odata[count * kOutputDim + 2] = bdata[0]; // xmin + odata[count * kOutputDim + 3] = bdata[1]; // ymin + odata[count * kOutputDim + 4] = bdata[2]; // xmax + odata[count * kOutputDim + 5] = bdata[3]; // ymax + } + count++; + } + } + + void Compute(const framework::ExecutionContext& ctx) const override { + auto* boxes = ctx.Input("Bboxes"); + auto* scores = ctx.Input("Scores"); + auto* outs = ctx.Output("Out"); + + auto box_dims = boxes->dims(); + auto score_dims = scores->dims(); + + int64_t batch_size = box_dims[0]; + int64_t class_num = score_dims[1]; + int64_t predict_dim = score_dims[2]; + + std::vector>> all_indices; + std::vector batch_starts = {0}; + for (int64_t i = 0; i < batch_size; ++i) { + Tensor ins_score = scores->Slice(i, i + 1); + ins_score.Resize({class_num, predict_dim}); + std::map> indices; + int num_nmsed_out = 0; + MulticlassNMS(ctx, ins_score, *boxes, &indices, &num_nmsed_out); + all_indices.push_back(indices); + batch_starts.push_back(batch_starts.back() + num_nmsed_out); + } + + int num_kept = batch_starts.back(); + if (num_kept == 0) { + outs->Resize({0, 0}); + } else { + outs->mutable_data({num_kept, kOutputDim}, ctx.GetPlace()); + for (int64_t i = 0; i < batch_size; ++i) { + Tensor ins_score = scores->Slice(i, i + 1); + ins_score.Resize({class_num, predict_dim}); + int64_t s = batch_starts[i]; + int64_t e = batch_starts[i + 1]; + if (e > s) { + Tensor out = outs->Slice(s, e); + MulticlassOutput(ins_score, *boxes, all_indices[i], &out); + } + } + } + + framework::LoD lod; + lod.emplace_back(batch_starts); + + outs->set_lod(lod); + } +}; + +class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MulticlassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Bboxes", + "(Tensor) A 2-D Tensor with shape [M, 4] represents the location " + "predictions with M bboxes. 4 is the number of " + "each location coordinates."); + AddOutput("Scores", + "(Tensor) A 3-D Tensor with shape [N, C, M] represents the " + "confidence predictions. N is the batch size, C is the class " + "number, M is number of predictions for each class, which is " + "the same with Bboxes."); + AddAttr( + "background_label", + "(int64_t, defalut: 0) " + "The index of background label, the background label will be ignored.") + .SetDefault(0); + AddAttr("nms_threshold", + "(float, defalut: 0.3) " + "The threshold to be used in nms.") + .SetDefault(0.3); + AddAttr("nms_top_k", + "(int64_t) " + " ."); + AddAttr("nms_eta", + "(float) " + "The parameter for adaptive nms.") + .SetDefault(1.0); + AddAttr("keep_top_k", + "(int64_t) " + "."); + AddAttr("confidence_threshold", + "(float) " + "."); + AddOutput("Out", + "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the " + "detections. Each row has 6 values: " + "[label, confidence, xmin, ymin, xmax, ymax], No is the total " + "number of detections in this mini-batch. For each instance, " + "the offsets in first dimension are called LoD, the number of " + "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " + "no detected bbox."); + AddComment(R"DOC( +This operators is to do multi-class non maximum suppression (nms) on a batched +of boxes and scores. + +This op greedily selects a subset of detection bounding boxes, pruning +away boxes that have high IOU (intersection over union) overlap (> thresh) +with already selected boxes. It operates independently for each class for +which scores are provided (via the scores field of the input box_list), +pruning boxes with score less than a provided threshold prior to +applying NMS. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(multiclass_nms, ops::MulticlassNMSOp, + ops::MulticlassNMSOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MulticlassNMSKernel, + ops::MulticlassNMSKernel); diff --git a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py index 741382989..c35fb20b1 100644 --- a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py +++ b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py @@ -62,7 +62,7 @@ def batch_bipartite_match(distance, lod): return match_indices, match_dist -class TestBipartiteMatchOpForWithLoD(OpTest): +class TestBipartiteMatchOpWithLoD(OpTest): def setUp(self): self.op_type = 'bipartite_match' lod = [[0, 5, 11, 23]] diff --git a/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py new file mode 100644 index 000000000..60c6488f8 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py @@ -0,0 +1,199 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +import unittest +import numpy as np +import copy +from op_test import OpTest + + +def iou(box_a, box_b): + """Apply intersection-over-union overlap between box_a and box_b + """ + xmin_a = min(box_a[0], box_a[2]) + ymin_a = min(box_a[1], box_a[3]) + xmax_a = max(box_a[0], box_a[2]) + ymax_a = max(box_a[1], box_a[3]) + + xmin_b = min(box_b[0], box_b[2]) + ymin_b = min(box_b[1], box_b[3]) + xmax_b = max(box_b[0], box_b[2]) + ymax_b = max(box_b[1], box_b[3]) + + area_a = (ymax_a - ymin_a) * (xmax_a - xmin_a) + area_b = (ymax_b - ymin_b) * (xmax_b - xmin_b) + if area_a <= 0 and area_b <= 0: + return 0.0 + + xa = max(xmin_a, xmin_b) + ya = max(ymin_a, ymin_b) + xb = min(xmax_a, xmax_b) + yb = min(ymax_a, ymax_b) + + inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0) + + box_a_area = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]) + box_b_area = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]) + + iou_ratio = inter_area / (area_a + area_b - inter_area) + + return iou_ratio + + +def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. + """ + all_scores = copy.deepcopy(scores) + all_scores = all_scores.flatten() + selected_indices = np.argwhere(all_scores > score_threshold) + selected_indices = selected_indices.flatten() + all_scores = all_scores[selected_indices] + + sorted_indices = np.argsort(-all_scores, axis=0) + sorted_scores = all_scores[sorted_indices] + if top_k < -1 and top_k < sorted_indices.shape[0]: + sorted_indices = sorted_indices[:top_k] + sorted_scores = sorted_scores[:top_k] + + selected_indices = [] + adaptive_threshold = nms_threshold + for i in range(sorted_scores.shape[0]): + idx = sorted_indices[i] + keep = True + for k in range(len(selected_indices)): + if keep: + kept_idx = selected_indices[k] + overlap = iou(boxes[idx], boxes[kept_idx]) + keep = overlap <= adaptive_threshold + else: + break + if keep: + selected_indices.append(idx) + if keep and eta < 1 and adaptive_threshold > 0.5: + adaptive_threshold *= eta + return selected_indices + + +def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, + nms_top_k, keep_top_k): + class_num = scores.shape[0] + priorbox_num = scores.shape[1] + + selected_indices = [] + num_det = 0 + for c in range(class_num): + if c == background: continue + indices = nms(boxes, scores[c], score_threshold, nms_threshold, + nms_top_k) + selected_indices.append((c, indices)) + num_det += len(indices) + + if keep_top_k > -1 and num_det > keep_top_k: + score_index = [] + for c, indices in selected_indices: + for idx in indices: + score_index.append((scores[c][idx], c, idx)) + + sorted_score_index = sorted( + score_index, key=lambda tup: tup[0], reverse=True) + sorted_score_index = sorted_score_index[:keep_top_k] + selected_indices = [] + for s, c, idx in sorted_score_index: + selected_indices.append((c, idx)) + + return selected_indices + + +def batched_multiclass_nms(boxes, scores, background, score_threshold, + nms_threshold, nms_top_k, keep_top_k): + batch_size = scores.shape[0] + + det_outs = [] + lod = [0] + for n in range(batch_size): + nmsed_outs = multiclass_nms(boxes, scores[n], background, + score_threshold, nms_threshold, nms_top_k, + keep_top_k) + lod.append(lod[-1] + len(nmsed_outs)) + if len(nmsed_outs) == 0: continue + for c, indices in nmsed_outs: + for idx in indices: + xmin, ymin, xmax, ymax = boxes[idx][:] + det_outs.append( + (c, scores[n][c][idx], c, xmin, ymin, xmax, ymax)) + return det_outs, lod + + +class TestMulticlassNMSOp(OpTest): + def setUp(self): + self.op_type = 'multiclass_nms' + N = 7 + M = 1230 + C = 21 + BOX_SIZE = 4 + background = 0 + nms_threshold = 0.3 + nms_top_k = 400 + keep_top_k = 200 + score_threshold = 0.01 + + scores = np.random.random((N, C, M)).astype('float32') + boxes = np.random.random((M, BOX_SIZE)).astype('float32') + boxes[:, 0:2] = boxes[:, 0:2] * 0.5 + boxes[:, 2:4] = boxes[:, 0:2] * 0.5 + 0.5 + + nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background, + score_threshold, nms_threshold, + nms_top_k, keep_top_k) + self.inputs = {'Bboxes': boxes, 'Scores': scores} + self.outputs = {'Out': (nmsed_outs, [lod])} + + def test_check_output(self): + self.check_output() + + +class TestIOU(unittest.TestCase): + def test_iou(self): + box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32') + box2 = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32') + + expt_output = np.array([2.0 / 16.0]).astype('float32') + calc_output = np.array([iou(box1, box2)]).astype('float32') + self.assertTrue(np.allclose(calc_output, expt_output)) + + +if __name__ == '__main__': + unittest.main() + # N = 7 + # M = 8 + # C = 5 + # BOX_SIZE = 4 + # background = 0 + # nms_threshold = 0.3 + # nms_top_k = 400 + # keep_top_k = 200 + # score_threshold = 0.5 + + # scores = np.random.random((N, C, M)).astype('float32') + # boxes = np.random.random((M, BOX_SIZE)).astype('float32') + # boxes[:, 0 : 2] = boxes[:, 0 : 2] * 0.5 + # boxes[:, 2 : 4] = boxes[:, 0 : 2] * 0.5 + 0.5 + # print nmsed_outs, lod -- GitLab