From 35dec3d7228e2f924ccc6549a420604110640337 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 30 Jan 2018 17:59:48 +0800 Subject: [PATCH] Fix bug in unit test. --- paddle/operators/multiclass_nms_op.cc | 84 +++++++++++-------- .../v2/fluid/tests/test_multiclass_nms_op.py | 61 +++++++------- 2 files changed, 82 insertions(+), 63 deletions(-) diff --git a/paddle/operators/multiclass_nms_op.cc b/paddle/operators/multiclass_nms_op.cc index 5da553a6cc2..93c8b5216f6 100644 --- a/paddle/operators/multiclass_nms_op.cc +++ b/paddle/operators/multiclass_nms_op.cc @@ -41,13 +41,22 @@ class MulticlassNMSOp : public framework::OperatorWithKernel { "The rank of Input(Bboxes) must be 3."); PADDLE_ENFORCE_EQ(score_dims.size(), 3, "The rank of Input(Scores) must be 3."); - PADDLE_ENFORCE_EQ(box_dims[2], 4); + PADDLE_ENFORCE_EQ(box_dims[1], 4); PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2]); // Here the box_dims[0] is not the real dimension of output. // It will be rewritten in the computing kernel. ctx->SetOutputDim("Out", {box_dims[0], 6}); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType( + ctx.Input("Scores")->type()), + ctx.device_context()); + } }; template @@ -158,12 +167,12 @@ class MulticlassNMSKernel : public framework::OpKernel { const Tensor& scores, const Tensor& bboxes, std::map>* indices, int* num_nmsed_out) const { - int64_t background_label = ctx.Attr("background_label"); - int64_t nms_top_k = ctx.Attr("nms_top_k"); - int64_t keep_top_k = ctx.Attr("keep_top_k"); + int64_t background_label = ctx.Attr("background_label"); + int64_t nms_top_k = ctx.Attr("nms_top_k"); + int64_t keep_top_k = ctx.Attr("keep_top_k"); T nms_threshold = static_cast(ctx.Attr("nms_threshold")); T nms_eta = static_cast(ctx.Attr("nms_eta")); - T score_threshold = static_cast(ctx.Attr("confidence_threshold")); + T score_threshold = static_cast(ctx.Attr("score_threshold")); int64_t class_num = scores.dims()[0]; int64_t predict_dim = scores.dims()[1]; @@ -173,7 +182,7 @@ class MulticlassNMSKernel : public framework::OpKernel { Tensor score = scores.Slice(c, c + 1); NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k, &((*indices)[c])); - num_det += indices[c].size(); + num_det += (*indices)[c].size(); } *num_nmsed_out = num_det; @@ -230,8 +239,8 @@ class MulticlassNMSKernel : public framework::OpKernel { odata[count * kOutputDim + 3] = bdata[1]; // ymin odata[count * kOutputDim + 4] = bdata[2]; // xmax odata[count * kOutputDim + 5] = bdata[3]; // ymax + count++; } - count++; } } @@ -240,10 +249,9 @@ class MulticlassNMSKernel : public framework::OpKernel { auto* scores = ctx.Input("Scores"); auto* outs = ctx.Output("Out"); - auto box_dims = boxes->dims(); auto score_dims = scores->dims(); - int64_t batch_size = box_dims[0]; + int64_t batch_size = score_dims[0]; int64_t class_num = score_dims[1]; int64_t predict_dim = score_dims[2]; @@ -291,35 +299,37 @@ class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) A 2-D Tensor with shape [M, 4] represents the location " "predictions with M bboxes. 4 is the number of " "each location coordinates."); - AddOutput("Scores", - "(Tensor) A 3-D Tensor with shape [N, C, M] represents the " - "confidence predictions. N is the batch size, C is the class " - "number, M is number of predictions for each class, which is " - "the same with Bboxes."); - AddAttr( + AddInput("Scores", + "(Tensor) A 3-D Tensor with shape [N, C, M] represents the " + "confidence predictions. N is the batch size, C is the class " + "number, M is number of predictions for each class, which is " + "the same with Bboxes."); + AddAttr( "background_label", "(int64_t, defalut: 0) " "The index of background label, the background label will be ignored.") .SetDefault(0); + AddAttr("score_threshold", + "(float) " + "Only consider detections whose confidences are larger than " + "a threshold. If not provided, consider all boxes."); + AddAttr("nms_top_k", + "(int64_t) " + "Maximum number of detections to be kept according to the " + "confidences aftern the filtering detections based on " + "score_threshold"); AddAttr("nms_threshold", "(float, defalut: 0.3) " - "The threshold to be used in nms.") + "The threshold to be used in NMS.") .SetDefault(0.3); - AddAttr("nms_top_k", - "(int64_t) " - "Maximum number of results to be kept."); AddAttr("nms_eta", "(float) " - "The parameter for adaptive nms.") + "The parameter for adaptive NMS.") .SetDefault(1.0); - AddAttr("keep_top_k", - "(int64_t) " - "Number of total bboxes to be kept per image after nms " - "step. -1 means keeping all bboxes after nms step."); - AddAttr("confidence_threshold", - "(float) " - "Only consider detections whose confidences are larger than " - "a threshold. If not provided, consider all boxes."); + AddAttr("keep_top_k", + "(int64_t) " + "Number of total bboxes to be kept per image after NMS " + "step. -1 means keeping all bboxes after NMS step."); AddOutput("Out", "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the " "detections. Each row has 6 values: " @@ -329,15 +339,21 @@ class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " "no detected bbox."); AddComment(R"DOC( -This operators is to do multi-class non maximum suppression (NMS) on a batched +This operator is to do multi-class non maximum suppression (NMS) on a batched of boxes and scores. -This op greedily selects a subset of detection bounding boxes, pruning -away boxes that have high IOU (intersection over union) overlap (> thresh) -with already selected boxes. It operates independently for each class for -which scores are provided, pruning boxes with score less than a provided -threshold prior to applying NMS. +In the NMS step, this operator greedily selects a subset of detection bounding +boxes that have high scores larger than score_threshold, if providing this +threshold, then selects the largest nms_top_k confidences scores if nms_top_k +is larger than -1. Then this operator pruns away boxes that have high IOU +(intersection over union) overlap with already selected boxes by adaptive +threshold NMS based on parameters of nms_threshold and nms_eta. + +Aftern NMS step, only at most keep_top_k number of total bboxes are to be kept +per image if keep_top_k is larger than -1. +This operator support multi-class and batched inputs. It applying NMS +independently for each class. )DOC"); } }; diff --git a/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py index 60c6488f84f..b619c52e550 100644 --- a/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py +++ b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py @@ -69,7 +69,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): sorted_indices = np.argsort(-all_scores, axis=0) sorted_scores = all_scores[sorted_indices] - if top_k < -1 and top_k < sorted_indices.shape[0]: + if top_k > -1 and top_k < sorted_indices.shape[0]: sorted_indices = sorted_indices[:top_k] sorted_scores = sorted_scores[:top_k] @@ -82,7 +82,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): if keep: kept_idx = selected_indices[k] overlap = iou(boxes[idx], boxes[kept_idx]) - keep = overlap <= adaptive_threshold + keep = True if overlap <= adaptive_threshold else False else: break if keep: @@ -103,14 +103,14 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, if c == background: continue indices = nms(boxes, scores[c], score_threshold, nms_threshold, nms_top_k) - selected_indices.append((c, indices)) + for idx in indices: + selected_indices.append((c, idx)) num_det += len(indices) if keep_top_k > -1 and num_det > keep_top_k: score_index = [] - for c, indices in selected_indices: - for idx in indices: - score_index.append((scores[c][idx], c, idx)) + for c, idx in selected_indices: + score_index.append((scores[c][idx], c, idx)) sorted_score_index = sorted( score_index, key=lambda tup: tup[0], reverse=True) @@ -134,19 +134,16 @@ def batched_multiclass_nms(boxes, scores, background, score_threshold, keep_top_k) lod.append(lod[-1] + len(nmsed_outs)) if len(nmsed_outs) == 0: continue - for c, indices in nmsed_outs: - for idx in indices: - xmin, ymin, xmax, ymax = boxes[idx][:] - det_outs.append( - (c, scores[n][c][idx], c, xmin, ymin, xmax, ymax)) + for c, idx in nmsed_outs: + xmin, ymin, xmax, ymax = boxes[idx][:] + det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax]) return det_outs, lod class TestMulticlassNMSOp(OpTest): def setUp(self): - self.op_type = 'multiclass_nms' N = 7 - M = 1230 + M = 1240 C = 21 BOX_SIZE = 4 background = 0 @@ -155,7 +152,17 @@ class TestMulticlassNMSOp(OpTest): keep_top_k = 200 score_threshold = 0.01 - scores = np.random.random((N, C, M)).astype('float32') + scores = np.random.random((N * M, C)).astype('float32') + + def softmax(x): + shiftx = x - np.max(x).clip(-64.) + exps = np.exp(shiftx) + return exps / np.sum(exps) + + scores = np.apply_along_axis(softmax, 1, scores) + scores = np.reshape(scores, (N, M, C)) + scores = np.transpose(scores, (0, 2, 1)) + boxes = np.random.random((M, BOX_SIZE)).astype('float32') boxes[:, 0:2] = boxes[:, 0:2] * 0.5 boxes[:, 2:4] = boxes[:, 0:2] * 0.5 + 0.5 @@ -163,8 +170,19 @@ class TestMulticlassNMSOp(OpTest): nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, nms_top_k, keep_top_k) + nmsed_outs = np.array(nmsed_outs).astype('float32') + + self.op_type = 'multiclass_nms' self.inputs = {'Bboxes': boxes, 'Scores': scores} self.outputs = {'Out': (nmsed_outs, [lod])} + self.attrs = { + 'background_label': 0, + 'nms_threshold': nms_threshold, + 'nms_top_k': nms_top_k, + 'keep_top_k': keep_top_k, + 'score_threshold': score_threshold, + 'nms_eta': 1.0, + } def test_check_output(self): self.check_output() @@ -182,18 +200,3 @@ class TestIOU(unittest.TestCase): if __name__ == '__main__': unittest.main() - # N = 7 - # M = 8 - # C = 5 - # BOX_SIZE = 4 - # background = 0 - # nms_threshold = 0.3 - # nms_top_k = 400 - # keep_top_k = 200 - # score_threshold = 0.5 - - # scores = np.random.random((N, C, M)).astype('float32') - # boxes = np.random.random((M, BOX_SIZE)).astype('float32') - # boxes[:, 0 : 2] = boxes[:, 0 : 2] * 0.5 - # boxes[:, 2 : 4] = boxes[:, 0 : 2] * 0.5 + 0.5 - # print nmsed_outs, lod -- GitLab