From f660553d7781c065ef61d09ca136373d7c983f0f Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Fri, 18 Jan 2019 08:41:27 +0000 Subject: [PATCH] enhance nms for mask rcnn, test=develop --- paddle/fluid/operators/detection/multiclass_nms_op.cc | 3 +-- .../fluid/tests/unittests/test_multiclass_nms_op.py | 10 ++++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index 680754dded3..14ce9937dc6 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -405,7 +405,7 @@ class MultiClassNMSKernel : public framework::OpKernel { if (num_kept == 0) { T* od = outs->mutable_data({1, 1}, ctx.GetPlace()); od[0] = -1; - batch_starts.back() = 1; + batch_starts = {0, 1}; } else { outs->mutable_data({num_kept, out_dim}, ctx.GetPlace()); if (score_dims.size() == 3) { @@ -443,7 +443,6 @@ class MultiClassNMSKernel : public framework::OpKernel { framework::LoD lod; lod.emplace_back(batch_starts); - LOG(ERROR) << "c++ lod: " << lod; outs->set_lod(lod); } diff --git a/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py b/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py index af36bcfaa08..2a50e0bd856 100644 --- a/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py +++ b/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py @@ -173,13 +173,15 @@ def lod_multiclass_nms(boxes, scores, background, score_threshold, normalized, shared=False) if nmsed_num == 0: - lod.append(1) + #lod.append(1) continue lod.append(nmsed_num) for c, indices in nmsed_outs.items(): for idx in indices: xmin, ymin, xmax, ymax = box[idx, c, :] det_outs.append([c, score[idx][c], xmin, ymin, xmax, ymax]) + if len(lod) == 0: + lod.append(1) return det_outs, lod @@ -208,7 +210,7 @@ def batched_multiclass_nms(boxes, normalized, shared=True) if nmsed_num == 0: - lod.append(1) + # lod.append(1) continue lod.append(nmsed_num) @@ -221,7 +223,8 @@ def batched_multiclass_nms(boxes, sorted_det_out = sorted( tmp_det_out, key=lambda tup: tup[0], reverse=False) det_outs.extend(sorted_det_out) - + if len(lod) == 0: + lod += [1] return det_outs, lod @@ -259,7 +262,6 @@ class TestMulticlassNMSOp(OpTest): nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, nms_top_k, keep_top_k) - print('python lod: ', lod) nmsed_outs = [-1] if not nmsed_outs else nmsed_outs nmsed_outs = np.array(nmsed_outs).astype('float32') -- GitLab