From f06c6193d709a4e04d2f7e111a3026de95022bce Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Tue, 23 Oct 2018 01:46:09 +0000 Subject: [PATCH] fix rpn target assign test=develop --- .../detection/rpn_target_assign_op.cc | 68 ++++++++++++++----- python/paddle/fluid/layers/detection.py | 15 ++-- python/paddle/fluid/tests/test_detection.py | 6 +- .../unittests/test_rpn_target_assign_op.py | 48 +++++++++---- 4 files changed, 100 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index dda423efd35..63895f8a1d5 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -52,6 +52,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { PADDLE_ENFORCE( ctx->HasOutput("TargetBBox"), "Output(TargetBBox) of RpnTargetAssignOp should not be null"); + PADDLE_ENFORCE( + ctx->HasOutput("BBox_inside_weight"), + "Output(BBox_inside_weight) of RpnTargetAssignOp should not be null"); auto anchor_dims = ctx->GetInputDim("Anchor"); auto gt_boxes_dims = ctx->GetInputDim("GtBoxes"); @@ -68,6 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ScoreIndex", {-1}); ctx->SetOutputDim("TargetLabel", {-1, 1}); ctx->SetOutputDim("TargetBBox", {-1, 4}); + ctx->SetOutputDim("BBox_inside_weight", {-1, 4}); } protected: @@ -169,6 +173,7 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, const float rpn_positive_overlap, const float rpn_negative_overlap, std::vector* fg_inds, std::vector* bg_inds, std::vector* tgt_lbl, + std::vector* fg_fake, std::vector* bbox_inside_weight, std::minstd_rand engine, bool use_random) { float epsilon = 0.00001; int anchor_num = anchor_to_gt_max.dims()[0]; @@ -201,12 +206,12 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, // Reservoir Sampling int fg_num = static_cast(rpn_fg_fraction * rpn_batch_size_per_im); ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random); - fg_num = static_cast(fg_inds_fake.size()); - for (int64_t i = 0; i < fg_num; ++i) { + int fg_fake_num = static_cast(fg_inds_fake.size()); + for (int64_t i = 0; i < fg_fake_num; ++i) { target_label[fg_inds_fake[i]] = 1; } - int bg_num = rpn_batch_size_per_im - fg_num; + int bg_num = rpn_batch_size_per_im - fg_fake_num; for (int64_t i = 0; i < anchor_num; ++i) { if (anchor_to_gt_max_data[i] < rpn_negative_overlap) { bg_inds_fake.push_back(i); @@ -214,12 +219,28 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, } ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random); bg_num = static_cast(bg_inds_fake.size()); + int fake_num = 0; for (int64_t i = 0; i < bg_num; ++i) { + // fg fake found + if (target_label[bg_inds_fake[i]] == 1) { + fake_num++; + fg_fake->emplace_back(fg_inds_fake[0]); + for (int j = 0; j < 4; ++j) { + bbox_inside_weight->emplace_back(T(0.)); + } + } target_label[bg_inds_fake[i]] = 0; } + for (int64_t i = 0; i < (fg_fake_num - fake_num) * 4; ++i) { + bbox_inside_weight->emplace_back(T(1.)); + } + for (int64_t i = 0; i < anchor_num; ++i) { - if (target_label[i] == 1) fg_inds->emplace_back(i); + if (target_label[i] == 1) { + fg_inds->emplace_back(i); + fg_fake->emplace_back(i); + } if (target_label[i] == 0) bg_inds->emplace_back(i); } fg_num = fg_inds->size(); @@ -248,7 +269,8 @@ std::vector SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx, std::vector bg_inds; std::vector gt_inds; std::vector tgt_lbl; - + std::vector fg_fake; + std::vector bbox_inside_weight; // Calculate the max IoU between anchors and gt boxes // Map from anchor to gt box that has highest overlap auto place = ctx.GetPlace(); @@ -275,32 +297,37 @@ std::vector SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx, // Follow the Faster RCNN's implementation ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max, rpn_batch_size_per_im, rpn_fg_fraction, rpn_positive_overlap, - rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, engine, - use_random); + rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, &fg_fake, + &bbox_inside_weight, engine, use_random); int fg_num = fg_inds.size(); int bg_num = bg_inds.size(); - gt_inds.reserve(fg_num); - for (int i = 0; i < fg_num; ++i) { - gt_inds.emplace_back(argmax[fg_inds[i]]); + int fg_fake_num = fg_fake.size(); + gt_inds.reserve(fg_fake_num); + for (int i = 0; i < fg_fake_num; ++i) { + gt_inds.emplace_back(argmax[fg_fake[i]]); } - - Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t; - int* loc_index_data = loc_index_t.mutable_data({fg_num}, place); + Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t, bbox_inside_weight_t; + int* loc_index_data = loc_index_t.mutable_data({fg_fake_num}, place); int* score_index_data = score_index_t.mutable_data({fg_num + bg_num}, place); int* tgt_lbl_data = tgt_lbl_t.mutable_data({fg_num + bg_num}, place); - int* gt_inds_data = gt_inds_t.mutable_data({fg_num}, place); - std::copy(fg_inds.begin(), fg_inds.end(), loc_index_data); + int* gt_inds_data = gt_inds_t.mutable_data({fg_fake_num}, place); + T* bbox_inside_weight_data = + bbox_inside_weight_t.mutable_data({fg_fake_num, 4}, place); + std::copy(fg_fake.begin(), fg_fake.end(), loc_index_data); std::copy(fg_inds.begin(), fg_inds.end(), score_index_data); std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num); std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data); std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data); + std::copy(bbox_inside_weight.begin(), bbox_inside_weight.end(), + bbox_inside_weight_data); std::vector loc_score_tgtlbl_gt; loc_score_tgtlbl_gt.emplace_back(loc_index_t); loc_score_tgtlbl_gt.emplace_back(score_index_t); loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t); loc_score_tgtlbl_gt.emplace_back(gt_inds_t); + loc_score_tgtlbl_gt.emplace_back(bbox_inside_weight_t); return loc_score_tgtlbl_gt; } @@ -318,6 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel { auto* score_index = context.Output("ScoreIndex"); auto* tgt_bbox = context.Output("TargetBBox"); auto* tgt_lbl = context.Output("TargetLabel"); + auto* bbox_inside_weight = context.Output("BBox_inside_weight"); PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL, "RpnTargetAssignOp gt_boxes needs 1 level of LoD"); @@ -340,7 +368,7 @@ class RpnTargetAssignKernel : public framework::OpKernel { score_index->mutable_data({max_num}, place); tgt_bbox->mutable_data({max_num, 4}, place); tgt_lbl->mutable_data({max_num, 1}, place); - + bbox_inside_weight->mutable_data({max_num, 4}, place); auto& dev_ctx = context.device_context(); std::random_device rnd; @@ -394,6 +422,7 @@ class RpnTargetAssignKernel : public framework::OpKernel { Tensor sampled_score_index = loc_score_tgtlbl_gt[1]; Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2]; Tensor sampled_gt_index = loc_score_tgtlbl_gt[3]; + Tensor sampled_bbox_inside_weight = loc_score_tgtlbl_gt[4]; int loc_num = sampled_loc_index.dims()[0]; int score_num = sampled_score_index.dims()[0]; @@ -432,6 +461,8 @@ class RpnTargetAssignKernel : public framework::OpKernel { AppendRpns(score_index, total_score_num, &sampled_score_index_unmap); AppendRpns(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox); AppendRpns(tgt_lbl, total_score_num, &sampled_tgtlbl); + AppendRpns(bbox_inside_weight, total_loc_num * 4, + &sampled_bbox_inside_weight); total_loc_num += loc_num; total_score_num += score_num; @@ -448,10 +479,12 @@ class RpnTargetAssignKernel : public framework::OpKernel { score_index->set_lod(loc_score); tgt_bbox->set_lod(lod_loc); tgt_lbl->set_lod(loc_score); + bbox_inside_weight->set_lod(lod_loc); loc_index->Resize({total_loc_num}); score_index->Resize({total_score_num}); tgt_bbox->Resize({total_loc_num, 4}); tgt_lbl->Resize({total_score_num, 1}); + bbox_inside_weight->Resize({total_loc_num, 4}); } }; @@ -514,6 +547,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { "TargetLabel", "(Tensor), The target labels of each anchor with shape " "[F + B, 1], F and B are sampled foreground and backgroud number."); + AddOutput("BBox_inside_weight", + "(Tensor), The bbox inside weight with shape " + "[F, 4], F is the sampled foreground number."); AddComment(R"DOC( This operator can be, for a given set of ground truth bboxes and the anchors, to assign classification and regression targets to each prediction. diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 1cfcbbb9c16..8026fa9398f 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -116,8 +116,8 @@ def rpn_target_assign(bbox_pred, Returns: tuple: A tuple(predicted_scores, predicted_location, target_label, - target_bbox) is returned. The predicted_scores and - predicted_location is the predicted result of the RPN. + target_bbox, bbox_inside_weight) is returned. The predicted_scores + and predicted_location is the predicted result of the RPN. The target_label and target_bbox is the ground truth, respectively. The predicted_location is a 2D Tensor with shape [F, 4], and the shape of target_bbox is same as the shape of @@ -126,6 +126,8 @@ def rpn_target_assign(bbox_pred, [F + B, 1], and the shape of target_label is same as the shape of the predicted_scores, B is the number of the background anchors, the F and B is depends on the input of this operator. + Bbox_inside_weight represents whether the predicted loc is fake_fg + or not and the shape is [F, 4]. Examples: .. code-block:: python @@ -138,7 +140,7 @@ def rpn_target_assign(bbox_pred, append_batch_size=False, dtype='float32') gt_boxes = layers.data(name='gt_boxes', shape=[10, 4], append_batch_size=False, dtype='float32') - loc_pred, score_pred, loc_target, score_target = + loc_pred, score_pred, loc_target, score_target, bbox_inside_weight = fluid.layers.rpn_target_assign(bbox_pred=bbox_pred, cls_logits=cls_logits, anchor_box=anchor_box, @@ -151,6 +153,7 @@ def rpn_target_assign(bbox_pred, score_index = helper.create_tmp_variable(dtype='int32') target_label = helper.create_tmp_variable(dtype='int32') target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype) + bbox_inside_weight = helper.create_tmp_variable(dtype=anchor_box.dtype) helper.append_op( type="rpn_target_assign", inputs={ @@ -163,7 +166,8 @@ def rpn_target_assign(bbox_pred, 'LocationIndex': loc_index, 'ScoreIndex': score_index, 'TargetLabel': target_label, - 'TargetBBox': target_bbox + 'TargetBBox': target_bbox, + 'BBox_inside_weight': bbox_inside_weight }, attrs={ 'rpn_batch_size_per_im': rpn_batch_size_per_im, @@ -178,13 +182,14 @@ def rpn_target_assign(bbox_pred, score_index.stop_gradient = True target_label.stop_gradient = True target_bbox.stop_gradient = True + bbox_inside_weight.stop_gradient = True cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1)) bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4)) predicted_cls_logits = nn.gather(cls_logits, score_index) predicted_bbox_pred = nn.gather(bbox_pred, loc_index) - return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox + return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight def detection_output(loc, diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 56129641ce5..b36b4272c73 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -301,7 +301,7 @@ class TestRpnTargetAssign(unittest.TestCase): dtype='float32', lod_level=1, append_batch_size=False) - pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign( + pred_scores, pred_loc, tgt_lbl, tgt_bbox, bbox_inside_weight = layers.rpn_target_assign( bbox_pred=bbox_pred, cls_logits=cls_logits, anchor_box=anchor_box, @@ -313,12 +313,14 @@ class TestRpnTargetAssign(unittest.TestCase): rpn_straddle_thresh=0.0, rpn_fg_fraction=0.5, rpn_positive_overlap=0.7, - rpn_negative_overlap=0.3) + rpn_negative_overlap=0.3, + use_random=False) self.assertIsNotNone(pred_scores) self.assertIsNotNone(pred_loc) self.assertIsNotNone(tgt_lbl) self.assertIsNotNone(tgt_bbox) + self.assertIsNotNone(bbox_inside_weight) assert pred_scores.shape[1] == 1 assert pred_loc.shape[1] == 4 assert pred_loc.shape[1] == tgt_bbox.shape[1] diff --git a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py index f63dbcd3d7f..fe1fa5e54de 100644 --- a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py @@ -50,8 +50,10 @@ def rpn_target_assign(anchor_by_gt_overlap, fg_inds, size=(len(fg_inds) - num_fg), replace=False) else: disable_inds = fg_inds[num_fg:] + labels[disable_inds] = -1 fg_inds = np.where(labels == 1)[0] + bbox_inside_weight = np.zeros((len(fg_inds), 4), dtype=np.float32) num_bg = rpn_batch_size_per_im - np.sum(labels == 1) bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0] @@ -59,18 +61,27 @@ def rpn_target_assign(anchor_by_gt_overlap, enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)] else: enable_inds = bg_inds[:num_bg] + + fg_fake_inds = np.array([], np.int32) + fg_value = np.array([fg_inds[0]], np.int32) + fake_num = 0 + for bg_id in enable_inds: + if bg_id in fg_inds: + fake_num += 1 + fg_fake_inds = np.hstack([fg_fake_inds, fg_value]) labels[enable_inds] = 0 + + bbox_inside_weight[fake_num:, :] = 1 fg_inds = np.where(labels == 1)[0] bg_inds = np.where(labels == 0)[0] - - loc_index = fg_inds - score_index = np.hstack((fg_inds, bg_inds)) + loc_index = np.hstack([fg_fake_inds, fg_inds]) + score_index = np.hstack([fg_inds, bg_inds]) labels = labels[score_index] assert not np.any(labels == -1), "Wrong labels with -1" - gt_inds = anchor_to_gt_argmax[fg_inds] + gt_inds = anchor_to_gt_argmax[loc_index] - return loc_index, score_index, labels, gt_inds + return loc_index, score_index, labels, gt_inds, bbox_inside_weight def get_anchor(n, c, h, w): @@ -123,9 +134,12 @@ def rpn_target_assign_in_python(all_anchors, gt_boxes_slice = gt_boxes_slice[not_crowd_inds] iou = _bbox_overlaps(inside_anchors, gt_boxes_slice) - loc_inds, score_inds, labels, gt_inds = rpn_target_assign( - iou, rpn_batch_size_per_im, rpn_positive_overlap, - rpn_negative_overlap, rpn_fg_fraction, use_random) + loc_inds, score_inds, labels, gt_inds, bbox_inside_weight = \ + rpn_target_assign(iou, rpn_batch_size_per_im, + rpn_positive_overlap, + rpn_negative_overlap, + rpn_fg_fraction, + use_random) # unmap to all anchor loc_inds = inds_inside[loc_inds] score_inds = inds_inside[score_inds] @@ -139,6 +153,7 @@ def rpn_target_assign_in_python(all_anchors, score_indexes = score_inds tgt_labels = labels tgt_bboxes = box_deltas + bbox_inside_weights = bbox_inside_weight else: loc_indexes = np.concatenate( [loc_indexes, loc_inds + i * anchor_num]) @@ -146,8 +161,10 @@ def rpn_target_assign_in_python(all_anchors, [score_indexes, score_inds + i * anchor_num]) tgt_labels = np.concatenate([tgt_labels, labels]) tgt_bboxes = np.vstack([tgt_bboxes, box_deltas]) + bbox_inside_weights = np.vstack([bbox_inside_weights, \ + bbox_inside_weight]) - return loc_indexes, score_indexes, tgt_bboxes, tgt_labels + return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights class TestRpnTargetAssignOp(OpTest): @@ -182,10 +199,12 @@ class TestRpnTargetAssignOp(OpTest): rpn_fg_fraction = 0.5 use_random = False - loc_index, score_index, tgt_bbox, labels = rpn_target_assign_in_python( - all_anchors, gt_boxes, is_crowd, im_info, lod, rpn_straddle_thresh, - rpn_batch_size_per_im, rpn_positive_overlap, rpn_negative_overlap, - rpn_fg_fraction, use_random) + loc_index, score_index, tgt_bbox, labels, bbox_inside_weights = \ + rpn_target_assign_in_python(all_anchors, gt_boxes, is_crowd, + im_info, lod, rpn_straddle_thresh, + rpn_batch_size_per_im, rpn_positive_overlap, + rpn_negative_overlap, + rpn_fg_fraction, use_random) labels = labels[:, np.newaxis] self.op_type = "rpn_target_assign" @@ -207,7 +226,8 @@ class TestRpnTargetAssignOp(OpTest): 'LocationIndex': loc_index.astype('int32'), 'ScoreIndex': score_index.astype('int32'), 'TargetBBox': tgt_bbox.astype('float32'), - 'TargetLabel': labels.astype('int32') + 'TargetLabel': labels.astype('int32'), + 'BBox_inside_weight': bbox_inside_weights.astype('float32') } def test_check_output(self): -- GitLab