From 9ed2f936f12f402679ac248ac169343408e5c2e8 Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Sat, 15 Jun 2019 20:51:40 +0800
Subject: [PATCH] add target assign operator for supporting retinanet (#17893)

* test=develop add target assign for retinanet
* test=develop run ci
* test=developp add test_layers
* test=develop add APi.spec
* test=develop alter round 1
* test=develop alter rpn_target_assign_op.cc
* test=develop alter test_rpn_target_assign_op.py
* test=develop alter rpn_target_assign_op.cc
* test=develop alter API.spec
* test=develop alter paddle/fluid/operators/detection/rpn_target_assign_op.cc
* test=develop alter rpn_target_assign_op.cc
* test=develop alter python/paddle/fluid/layers/detection.py
* test=develop alter paddle/fluid/API.spec
---
 paddle/fluid/API.spec                           |   1 +
 .../detection/rpn_target_assign_op.cc           | 469 +++++++++++++++++-
 python/paddle/fluid/layers/detection.py         | 159 ++++++
 .../fluid/tests/unittests/test_layers.py        |  47 ++
 .../unittests/test_rpn_target_assign_op.py      | 159 ++++++
 5 files changed, 826 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index b7f2fa96a7..96727d402d 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -348,6 +348,7 @@ paddle.fluid.layers.target_assign (ArgSpec(args=['input', 'matched_indices', 'ne
 paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0)), ('document', 'efae414c1137c7944d6174dd08c5347a'))
 paddle.fluid.layers.ssd_loss (ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None)), ('document', '6d5028fd09d01ab82d296adc0ea95aee'))
 paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)), ('document', '1e164a56fe9376e18a56d22563d9f801'))
+paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595'))
 paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '82b2aefeeb1b706bc4afec70928a259a'))
 paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc'))
 paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)), ('document', '9c601df88b251f22e9311c52939948cd'))
diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc
index 0b8053e8d0..338954346c 100644
--- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc
+++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc
@@ -202,21 +202,32 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
   }
 
   // Reservoir Sampling
-  int fg_num = static_cast<int>(rpn_fg_fraction * rpn_batch_size_per_im);
-  ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random);
+  int fg_num = 0;
+  if (rpn_fg_fraction > 0 && rpn_batch_size_per_im > 0) {
+    fg_num = static_cast<int>(rpn_fg_fraction * rpn_batch_size_per_im);
+    ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random);
+  } else {
+    fg_num = static_cast<int>(fg_inds_fake.size());
+  }
   int fg_fake_num = static_cast<int>(fg_inds_fake.size());
   for (int64_t i = 0; i < fg_fake_num; ++i) {
     target_label[fg_inds_fake[i]] = 1;
   }
 
-  int bg_num = rpn_batch_size_per_im - fg_fake_num;
   for (int64_t i = 0; i < anchor_num; ++i) {
     if (anchor_to_gt_max_data[i] < rpn_negative_overlap) {
       bg_inds_fake.push_back(i);
     }
   }
-  ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random);
-  bg_num = static_cast<int>(bg_inds_fake.size());
+  int bg_num = 0;
+  if (rpn_fg_fraction > 0 && rpn_batch_size_per_im > 0) {
+    bg_num = rpn_batch_size_per_im - fg_fake_num;
+    ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random);
+    bg_num = static_cast<int>(bg_inds_fake.size());
+  } else {
+    bg_num = static_cast<int>(bg_inds_fake.size());
+  }
+
   int fake_num = 0;
   for (int64_t i = 0; i < bg_num; ++i) {
     // fg fake found
@@ -492,9 +503,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Anchor",
              "(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4].");
     AddInput("GtBoxes",
-             "(LoDTensor) input groud-truth bbox with shape [K, 4].");
+             "(LoDTensor) input ground-truth bbox with shape [K, 4].");
     AddInput("IsCrowd",
-             "(LoDTensor) input which indicates groud-truth is crowd.");
+             "(LoDTensor) input which indicates ground-truth is crowd.");
     AddInput("ImInfo",
              "(LoDTensor) input image information with shape [N, 3]. "
              "N is the batch size, each image information includes height, "
@@ -536,7 +547,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
         "ScoreIndex",
         "(Tensor), The indexes of foreground and background anchors in all "
         "RPN anchors(The rest anchors are ignored). The shape of the "
-        "ScoreIndex is [F + B], F and B are sampled foreground and backgroud "
+        "ScoreIndex is [F + B], F and B are sampled foreground and background "
         " number.");
     AddOutput("TargetBBox",
               "(Tensor), The target bbox deltas with shape "
@@ -544,7 +555,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput(
         "TargetLabel",
         "(Tensor), The target labels of each anchor with shape "
-        "[F + B, 1], F and B are sampled foreground and backgroud number.");
+        "[F + B, 1], F and B are sampled foreground and background number.");
     AddOutput("BBoxInsideWeight",
               "(Tensor), The bbox inside weight with shape "
               "[F, 4], F is the sampled foreground number.");
@@ -573,6 +584,440 @@ negative do not contribute to the training objective.
   }
 };
 
+class RetinanetTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Anchor",
+             "(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4].");
+    AddInput("GtBoxes",
+             "(LoDTensor) input ground-truth bbox with shape [K, 4].");
+    AddInput("GtLabels",
+             "(LoDTensor) input ground-truth label with shape [K, 1].");
+    AddInput("IsCrowd",
+             "(LoDTensor) input which indicates ground-truth is crowd.");
+    AddInput("ImInfo",
+             "(LoDTensor) input image information with shape [N, 3]. "
+             "N is the batch size, each image information includes height, "
+             "width and scale.");
+    AddAttr<float>(
+        "positive_overlap",
+        "Minimum overlap required between an anchor and ground-truth "
+        "box for the (anchor, gt box) pair to be a positive example.")
+        .SetDefault(0.5);
+    AddAttr<float>(
+        "negative_overlap",
+        "Maximum overlap allowed between an anchor and ground-truth "
+        "box for the (anchor, gt box) pair to be a negative example.")
+        .SetDefault(0.4);
+    AddOutput(
+        "LocationIndex",
+        "(Tensor), The indexes of foreground anchors in all anchors, the "
+        "shape of the LocationIndex is [F], F depends on the value of input "
+        "tensor and attributes.");
+    AddOutput(
+        "ScoreIndex",
+        "(Tensor), The indexes of foreground and background anchors in all "
+        "RPN anchors(The rest anchors are ignored). The shape of the "
+        "ScoreIndex is [F + B], F and B are foreground and background "
+        "number.");
+    AddOutput("TargetBBox",
+              "(Tensor), The target bbox deltas with shape "
+              "[F, 4], F is the foreground number.");
+    AddOutput("TargetLabel",
+              "(Tensor), The target labels of each anchor with shape "
+              "[F + B, 1], F and B are foreground and background number.");
+    AddOutput("BBoxInsideWeight",
+              "(Tensor), The bbox inside weight with shape "
+              "[F, 4], F is the foreground number.");
+    AddOutput("ForegroundNumber",
+              "(Tensor), The foreground number with shape [1, 1].");
+    AddComment(R"DOC(
+    Given the Intersection-over-Union (IoU) overlap between anchors and
+    ground-truth boxes, this layer assigns classification and regression
+    targets to each anchor; these targets are used to train RetinaNet.
+
+    Every anchor is assigned a length-C one-hot vector of classification
+    targets and a 4-vector of box regression targets, where C is the class
+    number. The assignment rules are as follows:
+
+    1. An anchor is assigned to a ground-truth box when: (i) it has the highest
+    IoU overlap with that ground-truth box, or (ii) it has an IoU overlap higher
+    than positive_overlap (0.5) with any ground-truth box.
+
+    2. An anchor is assigned to background when its IoU overlap is lower than
+    negative_overlap (0.4) for all ground-truth boxes.
+
+    When an anchor is assigned to a ground-truth box of the i-th category, the
+    i-th entry in its C-vector of targets is set to 1 and all other entries are
+    set to 0. When an anchor is assigned to background, all entries are set to
+    0. Anchors that are not assigned do not contribute to the training
+    objective. The regression targets are the encoded ground-truth boxes
+    associated with the assigned anchors.
+
+)DOC");
+  }
+};
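The assignment rules in the DOC block above can be traced by hand on a small IoU matrix. The following standalone NumPy sketch is illustrative only (the IoU values are made up) and is not part of the operator; it mirrors rules 1 and 2 with the default thresholds:

    import numpy as np

    # IoU between 4 anchors (rows) and 2 ground-truth boxes (columns); assumed values.
    iou = np.array([[0.55, 0.10],   # above positive_overlap -> positive, rule 1(ii)
                    [0.30, 0.45],   # highest IoU for gt 1 -> positive, rule 1(i)
                    [0.20, 0.15],   # below negative_overlap for all gts -> background
                    [0.35, 0.42]])  # between thresholds, not a gt argmax -> ignored

    positive_overlap, negative_overlap = 0.5, 0.4
    labels = np.full(iou.shape[0], -1)               # -1 means "not assigned"
    labels[iou.max(axis=1) < negative_overlap] = 0   # rule 2: background
    labels[iou.argmax(axis=0)] = 1                   # rule 1(i): best anchor per gt
    labels[iou.max(axis=1) >= positive_overlap] = 1  # rule 1(ii): high-overlap anchors
    print(labels)  # [ 1  1  0 -1]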
+
+class RetinanetTargetAssignOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(
+        ctx->HasInput("Anchor"),
+        "Input(Anchor) of RetinanetTargetAssignOp should not be null");
+    PADDLE_ENFORCE(
+        ctx->HasInput("GtBoxes"),
+        "Input(GtBoxes) of RetinanetTargetAssignOp should not be null");
+    PADDLE_ENFORCE(
+        ctx->HasInput("GtLabels"),
+        "Input(GtLabels) of RetinanetTargetAssignOp should not be null");
+    PADDLE_ENFORCE(
+        ctx->HasInput("IsCrowd"),
+        "Input(IsCrowd) of RetinanetTargetAssignOp should not be null");
+    PADDLE_ENFORCE(
+        ctx->HasInput("ImInfo"),
+        "Input(ImInfo) of RetinanetTargetAssignOp should not be null");
+
+    PADDLE_ENFORCE(
+        ctx->HasOutput("LocationIndex"),
+        "Output(LocationIndex) of RetinanetTargetAssignOp should not be null");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("ScoreIndex"),
+        "Output(ScoreIndex) of RetinanetTargetAssignOp should not be null");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("TargetLabel"),
+        "Output(TargetLabel) of RetinanetTargetAssignOp should not be null");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("TargetBBox"),
+        "Output(TargetBBox) of RetinanetTargetAssignOp should not be null");
+    PADDLE_ENFORCE(ctx->HasOutput("BBoxInsideWeight"),
+                   "Output(BBoxInsideWeight) of RetinanetTargetAssignOp should "
+                   "not be null");
+    PADDLE_ENFORCE(ctx->HasOutput("ForegroundNumber"),
+                   "Output(ForegroundNumber) of RetinanetTargetAssignOp should "
+                   "not be null");
+
+    auto anchor_dims = ctx->GetInputDim("Anchor");
+    auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
+    auto gt_labels_dims = ctx->GetInputDim("GtLabels");
+    auto im_info_dims = ctx->GetInputDim("ImInfo");
+
+    PADDLE_ENFORCE_EQ(anchor_dims.size(), 2,
+                      "The rank of Input(Anchor) must be 2.");
+    PADDLE_ENFORCE_EQ(gt_boxes_dims.size(), 2,
+                      "The rank of Input(GtBoxes) must be 2.");
+    PADDLE_ENFORCE_EQ(gt_labels_dims.size(), 2,
+                      "The rank of Input(GtLabels) must be 2.");
+    PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
+                      "The rank of Input(ImInfo) must be 2.");
+
+    ctx->SetOutputDim("LocationIndex", {gt_labels_dims[0]});
+    ctx->SetOutputDim("ScoreIndex", {gt_labels_dims[0]});
+    ctx->SetOutputDim("TargetBBox", {gt_labels_dims[0], 4});
+    ctx->SetOutputDim("TargetLabel", {gt_labels_dims[0], 1});
+    ctx->SetOutputDim("BBoxInsideWeight", {gt_labels_dims[0], 4});
+    ctx->SetOutputDim("ForegroundNumber", {gt_labels_dims[0], 1});
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::LoDTensor>("Anchor")->type(),
+        platform::CPUPlace());
+  }
+};
+
+template <typename T>
+std::vector<Tensor> FilterCrowdGtBoxLabel(
+    const platform::CPUDeviceContext& context, Tensor* gt_boxes,
+    Tensor* gt_labels, Tensor* is_crowd) {
+  int gt_num = gt_boxes->dims()[0];
+  std::vector<int> not_crowd_inds;
+  auto* is_crowd_data = is_crowd->data<int>();
+  for (int i = 0; i < gt_num; ++i) {
+    if (is_crowd_data[i] == 0) {
+      not_crowd_inds.emplace_back(i);
+    }
+  }
+  int ncrowd_num = not_crowd_inds.size();
+  Tensor ncrowd_gt_boxes, ncrowd_gt_labels;
+  T* ncrowd_gt_boxes_data =
+      ncrowd_gt_boxes.mutable_data<T>({ncrowd_num, 4}, context.GetPlace());
+  int* ncrowd_gt_labels_data =
+      ncrowd_gt_labels.mutable_data<int>({ncrowd_num, 1}, context.GetPlace());
+  Gather<T>(gt_boxes->data<T>(), 4, not_crowd_inds.data(), ncrowd_num,
+            ncrowd_gt_boxes_data);
+  Gather<int>(gt_labels->data<int>(), 1, not_crowd_inds.data(), ncrowd_num,
+              ncrowd_gt_labels_data);
+  std::vector<Tensor> res;
+  res.emplace_back(ncrowd_gt_boxes);
+  res.emplace_back(ncrowd_gt_labels);
+  return res;
+}
+
+template <typename T>
+std::vector<Tensor> GetAllFgBgGt(const platform::CPUDeviceContext& ctx,
+                                 const Tensor& anchor_by_gt_overlap,
+                                 const Tensor& ncrowd_gt_labels,
+                                 const float positive_overlap,
+                                 const float negative_overlap,
+                                 std::minstd_rand engine) {
+  auto* anchor_by_gt_overlap_data = anchor_by_gt_overlap.data<T>();
+  int anchor_num = anchor_by_gt_overlap.dims()[0];
+  int gt_num = anchor_by_gt_overlap.dims()[1];
+
+  std::vector<int> fg_inds;
+  std::vector<int> bg_inds;
+  std::vector<int> gt_inds;
+  std::vector<int> tgt_lbl;
+  std::vector<int> fg_fake;
+  std::vector<T> bbox_inside_weight;
+  // Calculate the max IoU between anchors and gt boxes
+  // Map from anchor to gt box that has highest overlap
+  auto place = ctx.GetPlace();
+  Tensor anchor_to_gt_max, anchor_to_gt_argmax, gt_to_anchor_max;
+  anchor_to_gt_max.mutable_data<T>({anchor_num}, place);
+  int* argmax = anchor_to_gt_argmax.mutable_data<int>({anchor_num}, place);
+  gt_to_anchor_max.mutable_data<T>({gt_num}, place);
+
+  auto anchor_by_gt_overlap_et =
+      framework::EigenMatrix<T>::From(anchor_by_gt_overlap);
+  auto anchor_to_gt_max_et =
+      framework::EigenVector<T>::Flatten(anchor_to_gt_max);
+  auto gt_to_anchor_max_et =
+      framework::EigenVector<T>::Flatten(gt_to_anchor_max);
+  auto anchor_to_gt_argmax_et =
+      framework::EigenVector<int>::Flatten(anchor_to_gt_argmax);
+  anchor_to_gt_max_et =
+      anchor_by_gt_overlap_et.maximum(Eigen::DSizes<int, 1>(1));
+  anchor_to_gt_argmax_et =
+      anchor_by_gt_overlap_et.argmax(1).template cast<int>();
+  gt_to_anchor_max_et =
+      anchor_by_gt_overlap_et.maximum(Eigen::DSizes<int, 1>(0));
+
+  ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max, -1,
+              -1, positive_overlap, negative_overlap, &fg_inds, &bg_inds,
+              &tgt_lbl, &fg_fake, &bbox_inside_weight, engine, false);
+  const int* gt_labels_data = ncrowd_gt_labels.data<int>();
+  int64_t fg_num = fg_inds.size();
+  for (int64_t i = 0; i < fg_num; ++i) {
+    int gt_idx = argmax[fg_inds[i]];
+    tgt_lbl[i] = gt_labels_data[gt_idx];
+  }
+
+  int bg_num = bg_inds.size();
+  int fg_fake_num = fg_fake.size();
+  gt_inds.reserve(fg_fake_num);
+  for (int i = 0; i < fg_fake_num; ++i) {
+    gt_inds.emplace_back(argmax[fg_fake[i]]);
+  }
+
+  Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t, bbox_inside_weight_t;
+  Tensor fg_num_t;
+  int* loc_index_data = loc_index_t.mutable_data<int>({fg_fake_num}, place);
+  int* score_index_data =
+      score_index_t.mutable_data<int>({fg_num + bg_num}, place);
+  int* tgt_lbl_data = tgt_lbl_t.mutable_data<int>({fg_num + bg_num}, place);
+  int* gt_inds_data = gt_inds_t.mutable_data<int>({fg_fake_num}, place);
+  int* fg_num_data = fg_num_t.mutable_data<int>({1}, place);
+  T* bbox_inside_weight_data =
+      bbox_inside_weight_t.mutable_data<T>({fg_fake_num, 4}, place);
+  std::copy(fg_fake.begin(), fg_fake.end(), loc_index_data);
+  std::copy(fg_inds.begin(), fg_inds.end(), score_index_data);
+  std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num);
+  std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data);
+  std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data);
+  std::copy(bbox_inside_weight.begin(), bbox_inside_weight.end(),
+            bbox_inside_weight_data);
+  fg_num_data[0] = fg_fake.size() + 1;
+  std::vector<Tensor> loc_score_tgtlbl_gt;
+  loc_score_tgtlbl_gt.emplace_back(loc_index_t);
+  loc_score_tgtlbl_gt.emplace_back(score_index_t);
+  loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t);
+  loc_score_tgtlbl_gt.emplace_back(gt_inds_t);
+  loc_score_tgtlbl_gt.emplace_back(bbox_inside_weight_t);
+  loc_score_tgtlbl_gt.emplace_back(fg_num_t);
+
+  return loc_score_tgtlbl_gt;
+}
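The regression targets produced by BoxToDelta in the kernel below encode each matched ground-truth box relative to its anchor. A minimal NumPy sketch of that encoding, mirroring the `_box_to_delta` helper used by the unit test at the end of this patch; the +1 width/height convention and unit variance weights are assumptions for illustration:

    import numpy as np

    def box_to_delta_sketch(anchors, gt_boxes, weights=(1., 1., 1., 1.)):
        """Encode gt boxes as (dx, dy, dw, dh) deltas w.r.t. anchors."""
        aw = anchors[:, 2] - anchors[:, 0] + 1.0  # assuming the +1 box convention
        ah = anchors[:, 3] - anchors[:, 1] + 1.0
        ax = anchors[:, 0] + 0.5 * aw
        ay = anchors[:, 1] + 0.5 * ah

        gw = gt_boxes[:, 2] - gt_boxes[:, 0] + 1.0
        gh = gt_boxes[:, 3] - gt_boxes[:, 1] + 1.0
        gx = gt_boxes[:, 0] + 0.5 * gw
        gy = gt_boxes[:, 1] + 0.5 * gh

        wx, wy, ww, wh = weights
        dx = (gx - ax) / aw / wx          # center offsets, normalized by anchor size
        dy = (gy - ay) / ah / wy
        dw = np.log(gw / aw) / ww         # log-space size ratios
        dh = np.log(gh / ah) / wh
        return np.stack([dx, dy, dw, dh], axis=1)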
+
+template <typename T>
+class RetinanetTargetAssignKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* anchor = context.Input<Tensor>("Anchor");  // (H*W*A) * 4
+    auto* gt_boxes = context.Input<LoDTensor>("GtBoxes");
+    auto* gt_labels = context.Input<LoDTensor>("GtLabels");
+    auto* is_crowd = context.Input<LoDTensor>("IsCrowd");
+    auto* im_info = context.Input<LoDTensor>("ImInfo");
+
+    auto* loc_index = context.Output<LoDTensor>("LocationIndex");
+    auto* score_index = context.Output<LoDTensor>("ScoreIndex");
+    auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox");
+    auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel");
+    auto* bbox_inside_weight = context.Output<LoDTensor>("BBoxInsideWeight");
+    auto* fg_num = context.Output<LoDTensor>("ForegroundNumber");
+
+    PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL,
+                      "RetinanetTargetAssignOp gt_boxes needs 1 level of LoD");
+    PADDLE_ENFORCE_EQ(gt_labels->lod().size(), 1UL,
+                      "RetinanetTargetAssignOp gt_labels needs 1 level of LoD");
+    PADDLE_ENFORCE_EQ(is_crowd->lod().size(), 1UL,
+                      "RetinanetTargetAssignOp is_crowd needs 1 level of LoD");
+
+    int64_t anchor_num = static_cast<int64_t>(anchor->dims()[0]);
+    int64_t batch_num = static_cast<int64_t>(gt_boxes->lod().back().size() - 1);
+
+    float positive_overlap = context.Attr<float>("positive_overlap");
+    float negative_overlap = context.Attr<float>("negative_overlap");
+
+    int64_t max_num = batch_num * anchor_num;
+    auto place = context.GetPlace();
+
+    loc_index->mutable_data<int>({max_num}, place);
+    score_index->mutable_data<int>({max_num}, place);
+    tgt_bbox->mutable_data<T>({max_num, 4}, place);
+    tgt_lbl->mutable_data<int>({max_num, 1}, place);
+    bbox_inside_weight->mutable_data<T>({max_num, 4}, place);
+    fg_num->mutable_data<int>({batch_num, 1}, place);
+    auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
+
+    std::random_device rnd;
+    std::minstd_rand engine;
+    int seed = rnd();
+    engine.seed(seed);
+
+    framework::LoD lod_loc, loc_score, lod_fg;
+    std::vector<size_t> lod0_loc(1, 0);
+    std::vector<size_t> lod0_score(1, 0);
+    std::vector<size_t> lod0_fg(1, 0);
+
+    int total_loc_num = 0;
+    int total_score_num = 0;
+    int total_fg_num = 0;
+    auto gt_boxes_lod = gt_boxes->lod().back();
+    auto gt_labels_lod = gt_labels->lod().back();
+    auto is_crowd_lod = is_crowd->lod().back();
+    for (int i = 0; i < batch_num; ++i) {
+      Tensor gt_boxes_slice =
+          gt_boxes->Slice(gt_boxes_lod[i], gt_boxes_lod[i + 1]);
+      Tensor gt_labels_slice =
+          gt_labels->Slice(gt_labels_lod[i], gt_labels_lod[i + 1]);
+      Tensor is_crowd_slice =
+          is_crowd->Slice(is_crowd_lod[i], is_crowd_lod[i + 1]);
+      Tensor im_info_slice = im_info->Slice(i, i + 1);
+      auto* im_info_data = im_info_slice.data<T>();
+      auto im_height = im_info_data[0];
+      auto im_width = im_info_data[1];
+      auto im_scale = im_info_data[2];
+
+      // Filter straddle anchor
+      std::vector<Tensor> filter_output =
+          FilterStraddleAnchor<T>(dev_ctx, anchor, -1, im_height, im_width);
+      Tensor inds_inside = filter_output[0];
+      Tensor inside_anchor = filter_output[1];
+
+      // Filter crowd gt
+      std::vector<Tensor> ncrowd_output = FilterCrowdGtBoxLabel<T>(
+          dev_ctx, &gt_boxes_slice, &gt_labels_slice, &is_crowd_slice);
+      Tensor ncrowd_gt_boxes = ncrowd_output[0];
+      Tensor ncrowd_gt_labels = ncrowd_output[1];
+
+      auto ncrowd_gt_boxes_et =
+          framework::EigenTensor<T, 2>::From(ncrowd_gt_boxes);
+      ncrowd_gt_boxes_et = ncrowd_gt_boxes_et * im_scale;
+
+      Tensor anchor_by_gt_overlap;
+      anchor_by_gt_overlap.mutable_data<T>(
+          {inside_anchor.dims()[0], ncrowd_gt_boxes.dims()[0]}, place);
+      BboxOverlaps<T>(inside_anchor, ncrowd_gt_boxes, &anchor_by_gt_overlap);
+
+      auto loc_score_tgtlbl_gt =
+          GetAllFgBgGt<T>(dev_ctx, anchor_by_gt_overlap, ncrowd_gt_labels,
+                          positive_overlap, negative_overlap, engine);
+
+      Tensor sampled_loc_index = loc_score_tgtlbl_gt[0];
+      Tensor sampled_score_index = loc_score_tgtlbl_gt[1];
+      Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2];
+      Tensor sampled_gt_index = loc_score_tgtlbl_gt[3];
+      Tensor sampled_bbox_inside_weight = loc_score_tgtlbl_gt[4];
+      Tensor sampled_fg_num = loc_score_tgtlbl_gt[5];
+
+      int loc_num = sampled_loc_index.dims()[0];
+      int score_num = sampled_score_index.dims()[0];
+      // unmap to all anchor
+      Tensor sampled_loc_index_unmap, sampled_score_index_unmap;
+      sampled_loc_index_unmap.mutable_data<int>({loc_num}, place);
+      sampled_score_index_unmap.mutable_data<int>({score_num}, place);
+      Gather<int>(inds_inside.data<int>(), 1, sampled_loc_index.data<int>(),
+                  loc_num, sampled_loc_index_unmap.data<int>());
+      Gather<int>(inds_inside.data<int>(), 1, sampled_score_index.data<int>(),
+                  score_num, sampled_score_index_unmap.data<int>());
+
+      // get target bbox deltas
+      Tensor sampled_anchor, sampled_gt, sampled_tgt_bbox;
+      auto* sampled_anchor_data =
+          sampled_anchor.mutable_data<T>({loc_num, 4}, place);
+      auto* sampled_gt_data = sampled_gt.mutable_data<T>({loc_num, 4}, place);
+      Gather<T>(anchor->data<T>(), 4, sampled_loc_index_unmap.data<int>(),
+                loc_num, sampled_anchor_data);
+      Gather<T>(ncrowd_gt_boxes.data<T>(), 4, sampled_gt_index.data<int>(),
+                loc_num, sampled_gt_data);
+      sampled_tgt_bbox.mutable_data<T>({loc_num, 4}, place);
+      BoxToDelta<T>(loc_num, sampled_anchor, sampled_gt, nullptr, false,
+                    &sampled_tgt_bbox);
+
+      // Add anchor offset
+      int anchor_offset = i * anchor_num;
+      auto sampled_loc_index_unmap_et =
+          framework::EigenTensor<int, 1>::From(sampled_loc_index_unmap);
+      sampled_loc_index_unmap_et = sampled_loc_index_unmap_et + anchor_offset;
+      auto sampled_score_index_unmap_et =
+          framework::EigenTensor<int, 1>::From(sampled_score_index_unmap);
+      sampled_score_index_unmap_et =
+          sampled_score_index_unmap_et + anchor_offset;
+      AppendRpns<int>(loc_index, total_loc_num, &sampled_loc_index_unmap);
+      AppendRpns<int>(score_index, total_score_num, &sampled_score_index_unmap);
+      AppendRpns<T>(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox);
+      AppendRpns<int>(tgt_lbl, total_score_num, &sampled_tgtlbl);
+      AppendRpns<T>(bbox_inside_weight, total_loc_num * 4,
+                    &sampled_bbox_inside_weight);
+      AppendRpns<int>(fg_num, total_fg_num, &sampled_fg_num);
+
+      total_loc_num += loc_num;
+      total_score_num += score_num;
+      total_fg_num += 1;
+      lod0_loc.emplace_back(total_loc_num);
+      lod0_score.emplace_back(total_score_num);
+      lod0_fg.emplace_back(total_fg_num);
+    }
+
+    PADDLE_ENFORCE_LE(total_loc_num, max_num);
+    PADDLE_ENFORCE_LE(total_score_num, max_num);
+    PADDLE_ENFORCE_LE(total_fg_num, batch_num);
+
+    lod_loc.emplace_back(lod0_loc);
+    loc_score.emplace_back(lod0_score);
+    lod_fg.emplace_back(lod0_fg);
+    loc_index->set_lod(lod_loc);
+    score_index->set_lod(loc_score);
+    tgt_bbox->set_lod(lod_loc);
+    tgt_lbl->set_lod(loc_score);
+    bbox_inside_weight->set_lod(lod_loc);
+    fg_num->set_lod(lod_fg);
+    loc_index->Resize({total_loc_num});
+    score_index->Resize({total_score_num});
+    tgt_bbox->Resize({total_loc_num, 4});
+    tgt_lbl->Resize({total_score_num, 1});
+    bbox_inside_weight->Resize({total_loc_num, 4});
+    fg_num->Resize({total_fg_num, 1});
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -582,3 +1027,9 @@ REGISTER_OPERATOR(rpn_target_assign, ops::RpnTargetAssignOp,
                   paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(rpn_target_assign, ops::RpnTargetAssignKernel<float>,
                        ops::RpnTargetAssignKernel<double>);
+REGISTER_OPERATOR(retinanet_target_assign, ops::RetinanetTargetAssignOp,
+                  ops::RetinanetTargetAssignOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(retinanet_target_assign,
+                       ops::RetinanetTargetAssignKernel<float>,
+                       ops::RetinanetTargetAssignKernel<double>);
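The kernel above iterates over the mini-batch by slicing its LoD inputs: the level-0 offsets mark where each image's ground-truth rows begin and end. A small sketch of that slicing convention (the offsets and data here are assumed for illustration; the real kernel calls gt_boxes->Slice(b, e)):

    import numpy as np

    # Level-0 LoD offsets for a batch of two images: image 0 owns gt rows
    # [0, 4), image 1 owns gt rows [4, 8).
    lod = [0, 4, 8]
    gt_boxes = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)  # stand-in data

    for i in range(len(lod) - 1):
        b, e = lod[i], lod[i + 1]
        gt_boxes_slice = gt_boxes[b:e, :]  # same role as gt_boxes->Slice(b, e)
        print('image', i, 'has', e - b, 'ground-truth boxes')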
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 6ec46c5c90..d5225c3074 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -39,6 +39,7 @@ __all__ = [
     'detection_output',
     'ssd_loss',
     'rpn_target_assign',
+    'retinanet_target_assign',
     'anchor_generator',
     'roi_perspective_transform',
     'generate_proposal_labels',
@@ -57,6 +58,164 @@ __all__ = [
 ]
 
 
+def retinanet_target_assign(bbox_pred,
+                            cls_logits,
+                            anchor_box,
+                            anchor_var,
+                            gt_boxes,
+                            gt_labels,
+                            is_crowd,
+                            im_info,
+                            num_classes=1,
+                            positive_overlap=0.5,
+                            negative_overlap=0.4):
+    """
+    **Target Assign Layer for RetinaNet.**
+
+    Given the Intersection-over-Union (IoU) overlap between anchors and
+    ground-truth boxes, this layer assigns classification and regression
+    targets to each anchor; these targets are used to train RetinaNet.
+    Every anchor is assigned a length :attr:`num_classes` one-hot vector of
+    classification targets and a 4-vector of box regression targets. The
+    assignment rules are as follows:
+
+    1. An anchor is assigned to a ground-truth box when: (i) it has the highest
+    IoU overlap with that ground-truth box, or (ii) it has an IoU overlap
+    higher than positive_overlap (0.5) with any ground-truth box.
+
+    2. An anchor is assigned to background when its IoU overlap is lower than
+    negative_overlap (0.4) for all ground-truth boxes.
+
+    When an anchor is assigned to a ground-truth box of the i-th category, the
+    i-th entry in its C-vector of targets is set to 1 and all other entries are
+    set to 0. When an anchor is assigned to background, all entries are set to
+    0. Anchors that are not assigned do not contribute to the training
+    objective. The regression targets are the encoded ground-truth boxes
+    associated with the assigned anchors.
+
+    Args:
+        bbox_pred(Variable): A 3-D Tensor with shape [N, M, 4] represents the
+            predicted locations of M bounding boxes. N is the batch size,
+            and each bounding box has four coordinate values with the layout
+            [xmin, ymin, xmax, ymax].
+        cls_logits(Variable): A 3-D Tensor with shape [N, M, C] represents the
+            predicted classification scores. N is the batch size, C is the
+            number of classes (excluding background), M is the number of
+            bounding boxes.
+        anchor_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes,
+            each box is represented as [xmin, ymin, xmax, ymax];
+            [xmin, ymin] is the left top coordinate of the anchor box,
+            which is close to the origin of the coordinate system if the
+            input is the image feature map. [xmax, ymax] is the right bottom
+            coordinate of the anchor box.
+        anchor_var(Variable): A 2-D Tensor with shape [M, 4] holds expanded
+            variances of anchors.
+        gt_boxes(Variable): The ground-truth bounding boxes are a 2-D
+            LoDTensor with shape [Ng, 4], Ng is the total number of
+            ground-truth boxes of the mini-batch input.
+        gt_labels(Variable): The ground-truth labels are a 2-D LoDTensor with
+            shape [Ng, 1], Ng is the total number of ground-truth labels of
+            the mini-batch input.
+        is_crowd(Variable): A 1-D LoDTensor which indicates whether a
+            ground-truth box is crowd.
+        im_info(Variable): A 2-D LoDTensor with shape [N, 3]. N is the batch
+            size, 3 is the height, width and scale.
+        num_classes(int32): The number of classes.
+        positive_overlap(float): Minimum overlap required between an anchor
+            and ground-truth box for the (anchor, gt box) pair to be a
+            positive example.
+        negative_overlap(float): Maximum overlap allowed between an anchor
+            and ground-truth box for the (anchor, gt box) pair to be a
+            negative example.
+
+    Returns:
+        tuple:
+        A tuple (predicted_scores, predicted_location, target_label,
+        target_bbox, bbox_inside_weight, fg_num) is returned. The
+        predicted_scores and predicted_location are the predicted results
+        of the RetinaNet; the target_label and target_bbox are the
+        corresponding ground truth. The predicted_location is a 2-D Tensor
+        with shape [F, 4], and target_bbox has the same shape as
+        predicted_location; F is the number of foreground anchors. The
+        predicted_scores is a 2-D Tensor with shape [F + B, C], and the
+        shape of target_label is [F + B, 1]; B is the number of background
+        anchors, and F and B depend on the input of this operator.
+        bbox_inside_weight indicates whether a predicted location is a fake
+        foreground or not; its shape is [F, 4]. fg_num is the foreground
+        number (including fake foreground), which is needed by the focal
+        loss.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            bbox_pred = fluid.layers.data(name='bbox_pred', shape=[1, 100, 4],
+                append_batch_size=False, dtype='float32')
+            cls_logits = fluid.layers.data(name='cls_logits', shape=[1, 100, 10],
+                append_batch_size=False, dtype='float32')
+            anchor_box = fluid.layers.data(name='anchor_box', shape=[100, 4],
+                append_batch_size=False, dtype='float32')
+            anchor_var = fluid.layers.data(name='anchor_var', shape=[100, 4],
+                append_batch_size=False, dtype='float32')
+            gt_boxes = fluid.layers.data(name='gt_boxes', shape=[10, 4],
+                append_batch_size=False, dtype='float32')
+            gt_labels = fluid.layers.data(name='gt_labels', shape=[10, 1],
+                append_batch_size=False, dtype='int32')
+            is_crowd = fluid.layers.data(name='is_crowd', shape=[1],
+                append_batch_size=False, dtype='int32')
+            im_info = fluid.layers.data(name='im_info', shape=[1, 3],
+                append_batch_size=False, dtype='float32')
+            loc_pred, score_pred, loc_target, score_target, bbox_inside_weight, fg_num = \
+                fluid.layers.retinanet_target_assign(bbox_pred, cls_logits,
+                    anchor_box, anchor_var, gt_boxes, gt_labels, is_crowd,
+                    im_info, 10)
+
+    """
+
+    helper = LayerHelper('retinanet_target_assign', **locals())
+    # Assign target label to anchors
+    loc_index = helper.create_variable_for_type_inference(dtype='int32')
+    score_index = helper.create_variable_for_type_inference(dtype='int32')
+    target_label = helper.create_variable_for_type_inference(dtype='int32')
+    target_bbox = helper.create_variable_for_type_inference(
+        dtype=anchor_box.dtype)
+    bbox_inside_weight = helper.create_variable_for_type_inference(
+        dtype=anchor_box.dtype)
+    fg_num = helper.create_variable_for_type_inference(dtype='int32')
+    helper.append_op(
+        type="retinanet_target_assign",
+        inputs={
+            'Anchor': anchor_box,
+            'GtBoxes': gt_boxes,
+            'GtLabels': gt_labels,
+            'IsCrowd': is_crowd,
+            'ImInfo': im_info
+        },
+        outputs={
+            'LocationIndex': loc_index,
+            'ScoreIndex': score_index,
+            'TargetLabel': target_label,
+            'TargetBBox': target_bbox,
+            'BBoxInsideWeight': bbox_inside_weight,
+            'ForegroundNumber': fg_num
+        },
+        attrs={
+            'positive_overlap': positive_overlap,
+            'negative_overlap': negative_overlap
+        })
+
+    loc_index.stop_gradient = True
+    score_index.stop_gradient = True
+    target_label.stop_gradient = True
+    target_bbox.stop_gradient = True
+    bbox_inside_weight.stop_gradient = True
+    fg_num.stop_gradient = True
+
+    cls_logits = nn.reshape(x=cls_logits, shape=(-1, num_classes))
+    bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4))
+    predicted_cls_logits = nn.gather(cls_logits, score_index)
+    predicted_bbox_pred = nn.gather(bbox_pred, loc_index)
+
+    return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight, fg_num
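The reshape-and-gather tail of the function above selects only the sampled predictions so that losses are computed over [F, 4] and [F + B, C] slices rather than over all anchors. A NumPy sketch of its semantics, with hypothetical shapes and index values (one image, M=5 anchors, C=3 classes):

    import numpy as np

    cls_logits = np.random.rand(1, 5, 3).astype('float32')
    bbox_pred = np.random.rand(1, 5, 4).astype('float32')
    score_index = np.array([0, 2, 4])  # sampled foreground + background anchors
    loc_index = np.array([0])          # sampled foreground anchors only

    predicted_cls_logits = cls_logits.reshape(-1, 3)[score_index]  # [F + B, C]
    predicted_bbox_pred = bbox_pred.reshape(-1, 4)[loc_index]      # [F, 4]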
+
+
 def rpn_target_assign(bbox_pred,
                       cls_logits,
                       anchor_box,
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index e6277649e5..7fbadd40c3 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -2024,6 +2024,53 @@ class TestBook(LayerTest):
                 trans_std=0.1)
             return (out)
 
+    def test_retinanet_target_assign(self):
+        with program_guard(fluid.default_main_program(),
+                           fluid.default_startup_program()):
+            bbox_pred = layers.data(
+                name='bbox_pred',
+                shape=[1, 100, 4],
+                append_batch_size=False,
+                dtype='float32')
+            cls_logits = layers.data(
+                name='cls_logits',
+                shape=[1, 100, 10],
+                append_batch_size=False,
+                dtype='float32')
+            anchor_box = layers.data(
+                name='anchor_box',
+                shape=[100, 4],
+                append_batch_size=False,
+                dtype='float32')
+            anchor_var = layers.data(
+                name='anchor_var',
+                shape=[100, 4],
+                append_batch_size=False,
+                dtype='float32')
+            gt_boxes = layers.data(
+                name='gt_boxes',
+                shape=[10, 4],
+                append_batch_size=False,
+                dtype='float32')
+            gt_labels = layers.data(
+                name='gt_labels',
+                shape=[10, 1],
+                append_batch_size=False,
+                dtype='int32')
+            is_crowd = layers.data(
+                name='is_crowd',
+                shape=[1],
+                append_batch_size=False,
+                dtype='int32')
+            im_info = layers.data(
+                name='im_info',
+                shape=[1, 3],
+                append_batch_size=False,
+                dtype='float32')
+            return (layers.retinanet_target_assign(
+                bbox_pred, cls_logits, anchor_box, anchor_var, gt_boxes,
+                gt_labels, is_crowd, im_info, 10))
+
 
 if __name__ == '__main__':
     unittest.main()
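The reference implementation in the test below builds the anchor-by-gt IoU matrix with `_bbox_overlaps`, a helper imported from the existing RPN test utilities. A minimal NumPy sketch of that computation, assuming the [xmin, ymin, xmax, ymax] layout and the +1 width/height convention used elsewhere in these tests:

    import numpy as np

    def bbox_overlaps_sketch(anchors, gt_boxes):
        """IoU between every anchor and every gt box (+1 convention assumed)."""
        iou = np.zeros((anchors.shape[0], gt_boxes.shape[0]), dtype=np.float32)
        for j, gt in enumerate(gt_boxes):
            gt_area = (gt[2] - gt[0] + 1.) * (gt[3] - gt[1] + 1.)
            for i, an in enumerate(anchors):
                iw = min(an[2], gt[2]) - max(an[0], gt[0]) + 1.
                ih = min(an[3], gt[3]) - max(an[1], gt[1]) + 1.
                if iw > 0 and ih > 0:
                    an_area = (an[2] - an[0] + 1.) * (an[3] - an[1] + 1.)
                    iou[i, j] = iw * ih / (an_area + gt_area - iw * ih)
        return iou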
diff --git a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py
index 1a2c9bb5f4..3dba961dc9 100644
--- a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py
+++ b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py
@@ -167,6 +167,105 @@ def rpn_target_assign_in_python(all_anchors,
     return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights
 
 
+def retinanet_target_assign(anchor_by_gt_overlap, gt_labels, positive_overlap,
+                            negative_overlap):
+    anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(axis=1)
+    anchor_to_gt_max = anchor_by_gt_overlap[np.arange(
+        anchor_by_gt_overlap.shape[0]), anchor_to_gt_argmax]
+
+    gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(axis=0)
+    gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, np.arange(
+        anchor_by_gt_overlap.shape[1])]
+    anchors_with_max_overlap = np.where(
+        anchor_by_gt_overlap == gt_to_anchor_max)[0]
+
+    labels = np.ones((anchor_by_gt_overlap.shape[0], ), dtype=np.int32) * -1
+    labels[anchors_with_max_overlap] = 1
+    labels[anchor_to_gt_max >= positive_overlap] = 1
+
+    fg_inds = np.where(labels == 1)[0]
+    bbox_inside_weight = np.zeros((len(fg_inds), 4), dtype=np.float32)
+
+    bg_inds = np.where(anchor_to_gt_max < negative_overlap)[0]
+    enable_inds = bg_inds
+
+    fg_fake_inds = np.array([], np.int32)
+    fg_value = np.array([fg_inds[0]], np.int32)
+    fake_num = 0
+    for bg_id in enable_inds:
+        if bg_id in fg_inds:
+            fake_num += 1
+            fg_fake_inds = np.hstack([fg_fake_inds, fg_value])
+    labels[enable_inds] = 0
+
+    bbox_inside_weight[fake_num:, :] = 1
+    fg_inds = np.where(labels == 1)[0]
+    bg_inds = np.where(labels == 0)[0]
+    loc_index = np.hstack([fg_fake_inds, fg_inds])
+    score_index = np.hstack([fg_inds, bg_inds])
+    score_index_tmp = np.hstack([fg_inds])
+    labels = labels[score_index]
+
+    gt_inds = anchor_to_gt_argmax[loc_index]
+    label_inds = anchor_to_gt_argmax[score_index_tmp]
+    labels[0:len(fg_inds)] = np.squeeze(gt_labels[label_inds])
+    fg_num = len(fg_fake_inds) + len(fg_inds) + 1
+    assert not np.any(labels == -1), "Wrong labels with -1"
+    return loc_index, score_index, labels, gt_inds, bbox_inside_weight, fg_num
+
+
+def retinanet_target_assign_in_python(all_anchors, gt_boxes, gt_labels,
+                                      is_crowd, im_info, lod, positive_overlap,
+                                      negative_overlap):
+    anchor_num = all_anchors.shape[0]
+    batch_size = len(lod) - 1
+    for i in range(batch_size):
+        im_scale = im_info[i][2]
+
+        inds_inside = np.arange(all_anchors.shape[0])
+        inside_anchors = all_anchors
+        b, e = lod[i], lod[i + 1]
+        gt_boxes_slice = gt_boxes[b:e, :] * im_scale
+        gt_labels_slice = gt_labels[b:e, :]
+        is_crowd_slice = is_crowd[b:e]
+
+        not_crowd_inds = np.where(is_crowd_slice == 0)[0]
+        gt_boxes_slice = gt_boxes_slice[not_crowd_inds]
+        gt_labels_slice = gt_labels_slice[not_crowd_inds]
+        iou = _bbox_overlaps(inside_anchors, gt_boxes_slice)
+
+        loc_inds, score_inds, labels, gt_inds, bbox_inside_weight, fg_num = \
+            retinanet_target_assign(iou, gt_labels_slice,
+                                    positive_overlap, negative_overlap)
+        # unmap to all anchor
+        loc_inds = inds_inside[loc_inds]
+        score_inds = inds_inside[score_inds]
+
+        sampled_gt = gt_boxes_slice[gt_inds]
+        sampled_anchor = all_anchors[loc_inds]
+        box_deltas = _box_to_delta(sampled_anchor, sampled_gt,
+                                   [1., 1., 1., 1.])
+
+        if i == 0:
+            loc_indexes = loc_inds
+            score_indexes = score_inds
+            tgt_labels = labels
+            tgt_bboxes = box_deltas
+            bbox_inside_weights = bbox_inside_weight
+            fg_nums = [[fg_num]]
+        else:
+            loc_indexes = np.concatenate(
+                [loc_indexes, loc_inds + i * anchor_num])
+            score_indexes = np.concatenate(
+                [score_indexes, score_inds + i * anchor_num])
+            tgt_labels = np.concatenate([tgt_labels, labels])
+            tgt_bboxes = np.vstack([tgt_bboxes, box_deltas])
+            bbox_inside_weights = np.vstack([bbox_inside_weights, \
+                bbox_inside_weight])
+            fg_nums = np.concatenate([fg_nums, [[fg_num]]])
+
+    return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights, fg_nums
+
+
 class TestRpnTargetAssignOp(OpTest):
     def setUp(self):
         n, c, h, w = 2, 4, 14, 14
@@ -234,5 +333,65 @@ class TestRpnTargetAssignOp(OpTest):
         self.check_output()
 
 
+class TestRetinanetTargetAssignOp(OpTest):
+    def setUp(self):
+        n, c, h, w = 2, 4, 14, 14
+        all_anchors = get_anchor(n, c, h, w)
+        gt_num = 10
+        all_anchors = all_anchors.reshape(-1, 4)
+        anchor_num = all_anchors.shape[0]
+
+        images_shape = [[64, 64], [64, 64]]
+        groundtruth, lod = _generate_groundtruth(images_shape, 3, 4)
+        lod = [0, 4, 8]
+
+        im_info = np.ones((len(images_shape), 3)).astype(np.float32)
+        for i in range(len(images_shape)):
+            im_info[i, 0] = images_shape[i][0]
+            im_info[i, 1] = images_shape[i][1]
+            im_info[i, 2] = 0.8  # scale
+        gt_boxes = np.vstack([v['boxes'] for v in groundtruth])
+        is_crowd = np.hstack([v['is_crowd'] for v in groundtruth])
+        gt_labels = np.vstack([
+            v['gt_classes'].reshape(len(v['gt_classes']), 1)
+            for v in groundtruth
+        ])
+        gt_labels = gt_labels.reshape(len(gt_labels), 1)
+        all_anchors = all_anchors.astype('float32')
+        gt_boxes = gt_boxes.astype('float32')
+        gt_labels = gt_labels.astype('int32')
+
+        positive_overlap = 0.5
+        negative_overlap = 0.4
+
+        loc_index, score_index, tgt_bbox, labels, bbox_inside_weights, fg_num = \
+            retinanet_target_assign_in_python(all_anchors, gt_boxes, gt_labels,
+                                              is_crowd, im_info, lod,
+                                              positive_overlap,
+                                              negative_overlap)
+        labels = labels[:, np.newaxis]
+        self.op_type = "retinanet_target_assign"
+        self.inputs = {
+            'Anchor': all_anchors,
+            'GtBoxes': (gt_boxes, [[4, 4]]),
+            'GtLabels': (gt_labels, [[4, 4]]),
+            'IsCrowd': (is_crowd, [[4, 4]]),
+            'ImInfo': (im_info, [[1, 1]])
+        }
+        self.attrs = {
+            'positive_overlap': positive_overlap,
+            'negative_overlap': negative_overlap
+        }
+        self.outputs = {
+            'LocationIndex': loc_index.astype('int32'),
+            'ScoreIndex': score_index.astype('int32'),
+            'TargetBBox': tgt_bbox.astype('float32'),
+            'TargetLabel': labels.astype('int32'),
+            'BBoxInsideWeight': bbox_inside_weights.astype('float32'),
+            'ForegroundNumber': fg_num.astype('int32')
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
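For readers trying the new layer out, the following is a minimal end-to-end sketch under Fluid 1.5-era API assumptions; it is not part of the patch, the shapes are illustrative, and the LoD inputs would be fed via fluid.create_lod_tensor:

    import numpy as np
    import paddle.fluid as fluid

    # Graph inputs; dense tensors keep fixed shapes, ground-truth inputs are
    # variable-length LoD tensors (lod_level=1).
    bbox_pred = fluid.layers.data(name='bbox_pred', shape=[1, 100, 4],
                                  append_batch_size=False, dtype='float32')
    cls_logits = fluid.layers.data(name='cls_logits', shape=[1, 100, 10],
                                   append_batch_size=False, dtype='float32')
    anchor_box = fluid.layers.data(name='anchor_box', shape=[100, 4],
                                   append_batch_size=False, dtype='float32')
    anchor_var = fluid.layers.data(name='anchor_var', shape=[100, 4],
                                   append_batch_size=False, dtype='float32')
    gt_boxes = fluid.layers.data(name='gt_boxes', shape=[4], lod_level=1,
                                 dtype='float32')
    gt_labels = fluid.layers.data(name='gt_labels', shape=[1], lod_level=1,
                                  dtype='int32')
    is_crowd = fluid.layers.data(name='is_crowd', shape=[1], lod_level=1,
                                 dtype='int32')
    im_info = fluid.layers.data(name='im_info', shape=[1, 3],
                                append_batch_size=False, dtype='float32')

    outs = fluid.layers.retinanet_target_assign(
        bbox_pred, cls_logits, anchor_box, anchor_var, gt_boxes, gt_labels,
        is_crowd, im_info, 10)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    # Feed dense inputs as ndarrays and LoD inputs via create_lod_tensor, e.g.
    # fluid.create_lod_tensor(np.random.rand(4, 4).astype('float32'), [[4]], place)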