diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 7ae0f445a8d377d2540af2a577bac8186617f89c..085e09b502a1fe3383081d4488f972ac6fcdb588 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -299,6 +299,7 @@ paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', ' paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral')) paddle.fluid.layers.rpn_target_assign ArgSpec(args=['loc', 'scores', 'anchor_box', 'gt_box', 'rpn_batch_size_per_im', 'fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap'], varargs=None, keywords=None, defaults=(256, 0.25, 0.7, 0.3)) paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)) +paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)) paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index a44d84cd7b99107fef09a6b4dfa60172fabd718b..1301c8ae2b145298b18f68f86dadd3c5cbe4271a 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -29,6 +29,6 @@ target_assign_op.cu) detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc polygon_box_transform_op.cu) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) - -# Export local libraries to parent +detection_library(generate_proposals_op SRCS generate_proposals_op.cc) +#Export local libraries to parent set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d29b0153389574de8992b93ac6795e91556af870 --- /dev/null +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -0,0 +1,485 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/gather.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +struct AppendProposalsFunctor { + LoDTensor *out_; + int64_t offset_; + Tensor *to_add_; + + AppendProposalsFunctor(LoDTensor *out, int64_t offset, Tensor *to_add) + : out_(out), offset_(offset), to_add_(to_add) {} + + template + void operator()() const { + auto *out_data = out_->data(); + auto *to_add_data = to_add_->data(); + memcpy(out_data + offset_, to_add_data, to_add_->numel() * sizeof(T)); + } +}; + +class GenerateProposalsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Scores"), "Input(Scores) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("BboxDeltas"), + "Input(BboxDeltas) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("ImInfo"), "Input(ImInfo) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Anchors"), + "Input(Anchors) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Variances"), + "Input(Variances) shouldn't be null."); + + auto scores_dims = ctx->GetInputDim("Scores"); + auto bbox_deltas_dims = ctx->GetInputDim("BboxDeltas"); + auto im_info_dims = ctx->GetInputDim("ImInfo"); + auto anchors_dims = ctx->GetInputDim("Anchors"); + auto variances_dims = ctx->GetInputDim("Variances"); + + ctx->SetOutputDim("RpnRois", {-1, 4}); + ctx->SetOutputDim("RpnRoiProbs", {-1, 1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Anchors")->type()), + platform::CPUPlace()); + } +}; + +template +void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors, + Tensor *bbox_deltas, Tensor *variances, Tensor *proposals) { + T *proposals_data = proposals->mutable_data(ctx.GetPlace()); + + int64_t row = all_anchors->dims()[0]; + int64_t len = all_anchors->dims()[1]; + + auto *bbox_deltas_data = bbox_deltas->data(); + auto *anchor_data = all_anchors->data(); + const T *variances_data = nullptr; + if (variances) { + variances_data = variances->data(); + } + + for (int64_t i = 0; i < row; ++i) { + T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len]; + T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1]; + + T anchor_center_x = (anchor_data[i * len + 2] + anchor_data[i * len]) / 2; + T anchor_center_y = + (anchor_data[i * len + 3] + anchor_data[i * len + 1]) / 2; + + T bbox_center_x = 0, bbox_center_y = 0; + T bbox_width = 0, bbox_height = 0; + + if (variances) { + bbox_center_x = + variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width + + anchor_center_x; + bbox_center_y = variances_data[i * len + 1] * + bbox_deltas_data[i * len + 1] * anchor_height + + anchor_center_y; + bbox_width = std::exp(variances_data[i * len + 2] * + bbox_deltas_data[i * len + 2]) * + anchor_width; + bbox_height = std::exp(variances_data[i * len + 3] * + bbox_deltas_data[i * len + 3]) * + anchor_height; + } else { + bbox_center_x = + bbox_deltas_data[i * len] * anchor_width + anchor_center_x; + bbox_center_y = + bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y; + bbox_width = std::exp(bbox_deltas_data[i * len + 2]) * anchor_width; + 
bbox_height = std::exp(bbox_deltas_data[i * len + 3]) * anchor_height; + } + + proposals_data[i * len] = bbox_center_x - bbox_width / 2; + proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2; + proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2; + proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2; + } + // return proposals; +} + +template +void ClipTiledBoxes(const platform::DeviceContext &ctx, const Tensor &im_info, + Tensor *boxes) { + T *boxes_data = boxes->mutable_data(ctx.GetPlace()); + const T *im_info_data = im_info.data(); + for (int64_t i = 0; i < boxes->numel(); ++i) { + if (i % 4 == 0) { + boxes_data[i] = + std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f); + } else if (i % 4 == 1) { + boxes_data[i] = + std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f); + } else if (i % 4 == 2) { + boxes_data[i] = + std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f); + } else { + boxes_data[i] = + std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f); + } + } +} + +template +void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes, + float min_size, const Tensor &im_info, Tensor *keep) { + const T *im_info_data = im_info.data(); + T *boxes_data = boxes->mutable_data(ctx.GetPlace()); + min_size *= im_info_data[2]; + keep->Resize({boxes->dims()[0], 1}); + int *keep_data = keep->mutable_data(ctx.GetPlace()); + + int keep_len = 0; + for (int i = 0; i < boxes->dims()[0]; ++i) { + T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1; + T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1; + T x_ctr = boxes_data[4 * i] + ws / 2; + T y_ctr = boxes_data[4 * i + 1] + hs / 2; + if (ws >= min_size && hs >= min_size && x_ctr <= im_info_data[1] && + y_ctr <= im_info_data[0]) { + keep_data[keep_len++] = i; + } + } + keep->Resize({keep_len}); +} + +bool SortScorePairDescend(const std::pair &pair1, + const std::pair &pair2) { + return pair1.first > pair2.first; +} + +template +void GetMaxScoreIndex(const std::vector &scores, + std::vector> *sorted_indices) { + for (size_t i = 0; i < scores.size(); ++i) { + sorted_indices->push_back(std::make_pair(scores[i], i)); + } + // Sort the score pair according to the scores in descending order + std::stable_sort(sorted_indices->begin(), sorted_indices->end(), + SortScorePairDescend); +} + +template +T BBoxArea(const T *box, const bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. 
+ return (w + 1) * (h + 1); + } + } +} + +template +T JaccardOverlap(const T *box1, const T *box2, const bool normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return static_cast(0.); + } else { + const T inter_xmin = std::max(box1[0], box2[0]); + const T inter_ymin = std::max(box1[1], box2[1]); + const T inter_xmax = std::min(box1[2], box2[2]); + const T inter_ymax = std::min(box1[3], box2[3]); + const T inter_w = inter_xmax - inter_xmin; + const T inter_h = inter_ymax - inter_ymin; + const T inter_area = inter_w * inter_h; + const T bbox1_area = BBoxArea(box1, normalized); + const T bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores, + const T nms_threshold, const float eta) { + PADDLE_ENFORCE_NOT_NULL(bbox); + int64_t num_boxes = bbox->dims()[0]; + // 4: [xmin ymin xmax ymax] + int64_t box_size = bbox->dims()[1]; + + std::vector scores_data(num_boxes); + std::copy_n(scores->data(), num_boxes, scores_data.begin()); + std::vector> sorted_indices; + GetMaxScoreIndex(scores_data, &sorted_indices); + + std::vector selected_indices; + int selected_num = 0; + T adaptive_threshold = nms_threshold; + const T *bbox_data = bbox->data(); + bool flag; + while (sorted_indices.size() != 0) { + int idx = sorted_indices.front().second; + flag = true; + for (size_t k = 0; k < selected_indices.size(); ++k) { + if (flag) { + const int kept_idx = selected_indices[k]; + T overlap = JaccardOverlap(bbox_data + idx * box_size, + bbox_data + kept_idx * box_size, false); + flag = (overlap <= adaptive_threshold); + } else { + break; + } + } + if (flag) { + selected_indices.push_back(idx); + selected_num++; + } + sorted_indices.erase(sorted_indices.begin()); + if (flag && eta < 1 && adaptive_threshold > 0.5) { + adaptive_threshold *= eta; + } + } + Tensor keep_nms; + keep_nms.Resize({selected_num}); + int *keep_data = keep_nms.mutable_data(ctx.GetPlace()); + for (int i = 0; i < selected_num; ++i) { + keep_data[i] = selected_indices[i]; + } + + return keep_nms; +} + +template +class GenerateProposalsKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *scores = context.Input("Scores"); + auto *bbox_deltas = context.Input("BboxDeltas"); + auto *im_info = context.Input("ImInfo"); + auto *anchors = context.Input("Anchors"); + auto *variances = context.Input("Variances"); + + auto *rpn_rois = context.Output("RpnRois"); + auto *rpn_roi_probs = context.Output("RpnRoiProbs"); + + int pre_nms_top_n = context.Attr("pre_nms_topN"); + int post_nms_top_n = context.Attr("post_nms_topN"); + float nms_thresh = context.Attr("nms_thresh"); + float min_size = context.Attr("min_size"); + float eta = context.Attr("eta"); + + auto &dev_ctx = context.template device_context(); + + auto scores_dim = scores->dims(); + int64_t num = scores_dim[0]; + int64_t c_score = scores_dim[1]; + int64_t h_score = scores_dim[2]; + int64_t w_score = scores_dim[3]; + + auto bbox_dim = bbox_deltas->dims(); + int64_t c_bbox = bbox_dim[1]; + int64_t h_bbox = bbox_dim[2]; + int64_t w_bbox = bbox_dim[3]; + + rpn_rois->mutable_data({bbox_deltas->numel() / 4, 4}, + context.GetPlace()); + rpn_roi_probs->mutable_data({scores->numel() / 4, 1}, + context.GetPlace()); + + Tensor bbox_deltas_swap, scores_swap; + bbox_deltas_swap.mutable_data({num, h_bbox, w_bbox, c_bbox}, + 
dev_ctx.GetPlace()); + scores_swap.mutable_data({num, h_score, w_score, c_score}, + dev_ctx.GetPlace()); + + math::Transpose trans; + std::vector axis = {0, 2, 3, 1}; + trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis); + trans(dev_ctx, *scores, &scores_swap, axis); + + framework::LoD lod; + std::vector lod0(1, 0); + Tensor *anchor = const_cast(anchors); + anchor->Resize({anchors->numel() / 4, 4}); + Tensor *var = const_cast(variances); + var->Resize({var->numel() / 4, 4}); + + int64_t num_proposals = 0; + for (int64_t i = 0; i < num; ++i) { + Tensor im_info_slice = im_info->Slice(i, i + 1); + Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1); + Tensor scores_slice = scores_swap.Slice(i, i + 1); + + bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4}); + scores_slice.Resize({h_score * w_score * c_score, 1}); + + std::pair tensor_pair = + ProposalForOneImage(dev_ctx, im_info_slice, *anchor, *var, + bbox_deltas_slice, scores_slice, pre_nms_top_n, + post_nms_top_n, nms_thresh, min_size, eta); + Tensor proposals = tensor_pair.first; + Tensor scores = tensor_pair.second; + + framework::VisitDataType( + framework::ToDataType(rpn_rois->type()), + AppendProposalsFunctor(rpn_rois, 4 * num_proposals, &proposals)); + framework::VisitDataType( + framework::ToDataType(rpn_roi_probs->type()), + AppendProposalsFunctor(rpn_roi_probs, num_proposals, &scores)); + + num_proposals += proposals.dims()[0]; + lod0.emplace_back(num_proposals); + } + + lod.emplace_back(lod0); + rpn_rois->set_lod(lod); + rpn_roi_probs->set_lod(lod); + rpn_rois->Resize({num_proposals, 4}); + rpn_roi_probs->Resize({num_proposals, 1}); + } + + std::pair ProposalForOneImage( + const DeviceContext &ctx, const Tensor &im_info_slice, + const Tensor &anchors, const Tensor &variances, + const Tensor &bbox_deltas_slice, // [M, 4] + const Tensor &scores_slice, // [N, 1] + int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size, + float eta) const { + auto *scores_data = scores_slice.data(); + + // Sort index + Tensor index_t; + index_t.Resize({scores_slice.numel()}); + int *index = index_t.mutable_data(ctx.GetPlace()); + for (int i = 0; i < scores_slice.numel(); ++i) { + index[i] = i; + } + std::function compare = + [scores_data](const int64_t &i, const int64_t &j) { + return scores_data[i] > scores_data[j]; + }; + + if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) { + std::sort(index, index + scores_slice.numel(), compare); + } else { + std::nth_element(index, index + pre_nms_top_n, + index + scores_slice.numel(), compare); + index_t.Resize({pre_nms_top_n}); + } + + Tensor scores_sel, bbox_sel, anchor_sel, var_sel; + scores_sel.mutable_data({index_t.numel(), 1}, ctx.GetPlace()); + bbox_sel.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); + anchor_sel.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); + var_sel.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); + + CPUGather(ctx, scores_slice, index_t, &scores_sel); + CPUGather(ctx, bbox_deltas_slice, index_t, &bbox_sel); + CPUGather(ctx, anchors, index_t, &anchor_sel); + CPUGather(ctx, variances, index_t, &var_sel); + + Tensor proposals; + proposals.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); + BoxCoder(ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals); + + ClipTiledBoxes(ctx, im_info_slice, &proposals); + + Tensor keep; + FilterBoxes(ctx, &proposals, min_size, im_info_slice, &keep); + + Tensor scores_filter; + bbox_sel.mutable_data({keep.numel(), 4}, ctx.GetPlace()); + scores_filter.mutable_data({keep.numel(), 1}, 
ctx.GetPlace());
+    CPUGather<T>(ctx, proposals, keep, &bbox_sel);
+    CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
+    if (nms_thresh <= 0) {
+      return std::make_pair(bbox_sel, scores_filter);
+    }
+
+    Tensor keep_nms = NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);
+
+    if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
+      keep_nms.Resize({post_nms_top_n});
+    }
+
+    proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
+    scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
+    CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
+    CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
+
+    return std::make_pair(proposals, scores_sel);
+  }
+};
+
+class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Scores", "The probability of each anchor being a foreground object.");
+    AddInput("BboxDeltas", "The predicted bbox deltas from the RPN.");
+    AddInput("ImInfo", "Image information: height, width and scale.");
+    AddInput("Anchors", "All anchors.");
+    AddInput("Variances", "The expanded variances of anchors.");
+
+    AddOutput("RpnRois", "The generated proposal boxes.");
+    AddOutput("RpnRoiProbs", "The scores of the generated proposal boxes.");
+    AddAttr<int>("pre_nms_topN", "Number of boxes kept per image before NMS.");
+    AddAttr<int>("post_nms_topN", "Number of boxes kept per image after NMS.");
+    AddAttr<float>("nms_thresh", "NMS threshold.");
+    AddAttr<float>("min_size", "Minimum box height/width to keep.");
+    AddAttr<float>("eta", "Shrink factor of the adaptive NMS threshold.");
+    AddComment(R"DOC(
+Generate Proposals OP
+
+This operator proposes RoIs according to each box's probability of being a
+foreground object; the boxes are computed from anchors. BboxDeltas and Scores
+are the outputs of the RPN. The final proposals can be used to train the
+detection net.
+
+Scores is the probability of each box being an object, in (N, A, H, W) format,
+where N is the batch size, A is the number of anchors, and H and W are the
+height and width of the feature map.
+BboxDeltas is the difference between the predicted box location and the anchor
+location, in (N, 4*A, H, W) format.
+
+To generate proposals, this operator transposes and reshapes scores and
+bbox_deltas to (H*W*A, 1) and (H*W*A, 4) and computes box locations as
+proposal candidates. It then clips the boxes to the image, removes predicted
+boxes with small area, and finally applies NMS to obtain the final proposals.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(generate_proposals, ops::GenerateProposalsOp,
+                  ops::GenerateProposalsOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    generate_proposals,
+    ops::GenerateProposalsKernel<paddle::platform::CPUDeviceContext, float>);
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 72071478845df444ce72ce946787b2d0ce5f0d23..a5bc1fa8f801066a8281a828eb87dbccb4bb0eff 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -39,6 +39,7 @@ __all__ = [
     'detection_map',
     'rpn_target_assign',
     'anchor_generator',
+    'generate_proposals',
 ]
 
 __auto__ = [
@@ -1253,3 +1254,73 @@ def anchor_generator(input,
     anchor.stop_gradient = True
     var.stop_gradient = True
     return anchor, var
+
+
+def generate_proposals(scores,
+                       bbox_deltas,
+                       im_info,
+                       anchors,
+                       variances,
+                       pre_nms_top_n=6000,
+                       post_nms_top_n=1000,
+                       nms_thresh=0.5,
+                       min_size=0.1,
+                       eta=1.0,
+                       name=None):
+    """
+    **Generate Proposals of Faster-RCNN**
+
+    This operation proposes RoIs according to each box's probability of being
+    a foreground object; the boxes are computed from anchors. Bbox_deltas and
+    scores are the outputs of the RPN. The final proposals can be used to
+    train the detection net.
+
+    For generating proposals, this operation performs the following steps:
+
+    1. Transposes and reshapes scores and bbox_deltas to sizes (H*W*A, 1) and (H*W*A, 4).
+    2. Calculates box locations as proposal candidates.
+    3. Clips boxes to the image.
+    4. Removes predicted boxes with small area.
+    5. Applies NMS to get the final proposals as output.
+
+
+    Args:
+        scores(Variable): A 4-D Tensor with shape [N, A, H, W] that represents the probability of each box being an object.
+            N is the batch size, A is the number of anchors, H and W are the height and width of the feature map.
+        bbox_deltas(Variable): A 4-D Tensor with shape [N, 4*A, H, W] that represents the difference between the predicted box location and the anchor location.
+        im_info(Variable): A 2-D Tensor with shape [N, 3] that holds the origin image information for the N images. It contains the height, the width and the scale
+            between the origin image size and the size of the feature map.
+        anchors(Variable): A 4-D Tensor that represents the anchors with a layout of [H, W, A, 4]. H and W are the height and width of the feature map,
+            A is the number of anchors at each position. Each anchor is in (xmin, ymin, xmax, ymax) format and unnormalized.
+        variances(Variable): The expanded variances of anchors with a layout of [H, W, A, 4]. Each variance is in (xcenter, ycenter, w, h) format.
+        pre_nms_top_n(int): Number of total bboxes to be kept per image before NMS. 6000 by default.
+        post_nms_top_n(int): Number of total bboxes to be kept per image after NMS. 1000 by default.
+        nms_thresh(float): Threshold used in NMS. 0.5 by default.
+        min_size(float): Remove predicted boxes with either height or width < min_size. 0.1 by default.
+        eta(float): Used in adaptive NMS; while the adaptive threshold > 0.5, adaptive_threshold = adaptive_threshold * eta in each iteration.
+
+    Returns:
+        tuple: A tuple (rpn_rois, rpn_roi_probs) of 2-D LoDTensors with shapes [M, 4] and [M, 1], holding the generated proposals and their scores.
+    """
+    helper = LayerHelper('generate_proposals', **locals())
+
+    rpn_rois = helper.create_tmp_variable(dtype=bbox_deltas.dtype)
+    rpn_roi_probs = helper.create_tmp_variable(dtype=scores.dtype)
+    helper.append_op(
+        type="generate_proposals",
+        inputs={
+            'Scores': scores,
+            'BboxDeltas': bbox_deltas,
+            'ImInfo': im_info,
+            'Anchors': anchors,
+            'Variances': variances
+        },
+        attrs={
+            'pre_nms_topN': pre_nms_top_n,
+            'post_nms_topN': post_nms_top_n,
+            'nms_thresh': nms_thresh,
+            'min_size': min_size,
+            'eta': eta
+        },
+        outputs={'RpnRois': rpn_rois,
+                 'RpnRoiProbs': rpn_roi_probs})
+    rpn_rois.stop_gradient = True
+    rpn_roi_probs.stop_gradient = True
+
+    return rpn_rois, rpn_roi_probs
diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py
index 1467e72caac26a3ea2a0c770d665141988696630..b71b440d3c9155e63220f5f2d7e849e8332bfb16 100644
--- a/python/paddle/fluid/tests/test_detection.py
+++ b/python/paddle/fluid/tests/test_detection.py
@@ -201,5 +201,44 @@ class TestDetectionMAP(unittest.TestCase):
         print(str(program))
 
+
+class TestGenerateProposals(unittest.TestCase):
+    def test_generate_proposals(self):
+        data_shape = [20, 64, 64]
+        images = fluid.layers.data(
+            name='images', shape=data_shape, dtype='float32')
+        im_info = fluid.layers.data(
+            name='im_info', shape=[1, 3], dtype='float32')
+        anchors, variances = fluid.layers.anchor_generator(
+            name='anchor_generator',
+            input=images,
+            anchor_sizes=[32, 64],
+            aspect_ratios=[1.0],
+            variance=[0.1, 0.1, 0.2, 0.2],
+            stride=[16.0, 16.0],
+            offset=0.5)
+        num_anchors = anchors.shape[2]
+        scores = fluid.layers.data(
+            name='scores', shape=[1, num_anchors, 8, 8], dtype='float32')
+        bbox_deltas = fluid.layers.data(
+            name='bbox_deltas',
+            shape=[1, num_anchors * 4, 8,
8], + dtype='float32') + rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals( + name='generate_proposals', + scores=scores, + bbox_deltas=bbox_deltas, + im_info=im_info, + anchors=anchors, + variances=variances, + pre_nms_top_n=6000, + post_nms_top_n=1000, + nms_thresh=0.5, + min_size=0.1, + eta=1.0) + self.assertIsNotNone(rpn_rois) + self.assertIsNotNone(rpn_roi_probs) + print(rpn_rois.shape) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposals.py b/python/paddle/fluid/tests/unittests/test_generate_proposals.py new file mode 100644 index 0000000000000000000000000000000000000000..3fbd2ce95a4f22b91cd4955f914e12f422b0ee83 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_generate_proposals.py @@ -0,0 +1,320 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://w_idxw.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import sys +import math +import paddle.fluid as fluid +from op_test import OpTest +from test_multiclass_nms_op import nms +from test_anchor_generator_op import anchor_generator_in_python +import copy + + +def generate_proposals_in_python(scores, bbox_deltas, im_info, anchors, + variances, pre_nms_topN, post_nms_topN, + nms_thresh, min_size, eta): + all_anchors = anchors.reshape(-1, 4) + rois = np.empty((0, 5), dtype=np.float32) + roi_probs = np.empty((0, 1), dtype=np.float32) + + rpn_rois = [] + rpn_roi_probs = [] + lod = [] + num_images = scores.shape[0] + for img_idx in range(num_images): + img_i_boxes, img_i_probs = proposal_for_one_image( + im_info[img_idx, :], all_anchors, variances, + bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :], + pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta) + lod.append(img_i_probs.shape[0]) + rpn_rois.append(img_i_boxes) + rpn_roi_probs.append(img_i_probs) + + return rpn_rois, rpn_roi_probs, lod + + +def proposal_for_one_image(im_info, all_anchors, variances, bbox_deltas, scores, + pre_nms_topN, post_nms_topN, nms_thresh, min_size, + eta): + # Transpose and reshape predicted bbox transformations to get them + # into the same order as the anchors: + # - bbox deltas will be (4 * A, H, W) format from conv output + # - transpose to (H, W, 4 * A) + # - reshape to (H * W * A, 4) where rows are ordered by (H, W, A) + # in slowest to fastest order to match the enumerated anchors + bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape(-1, 4) + all_anchors = all_anchors.reshape(-1, 4) + variances = variances.reshape(-1, 4) + # Same story for the scores: + # - scores are (A, H, W) format from conv output + # - transpose to (H, W, A) + # - reshape to (H * W * A, 1) where rows are ordered by (H, W, A) + # to match the order of anchors and bbox_deltas + scores = scores.transpose((1, 2, 0)).reshape(-1, 1) + + # sort all (proposal, score) pairs by score from highest to lowest + # take top pre_nms_topN (e.g. 
6000) + if pre_nms_topN <= 0 or pre_nms_topN >= len(scores): + order = np.argsort(-scores.squeeze()) + else: + # Avoid sorting possibly large arrays; + # First partition to get top K unsorted + # and then sort just thoes + inds = np.argpartition(-scores.squeeze(), pre_nms_topN)[:pre_nms_topN] + order = np.argsort(-scores[inds].squeeze()) + order = inds[order] + scores = scores[order, :] + bbox_deltas = bbox_deltas[order, :] + all_anchors = all_anchors[order, :] + proposals = box_coder(all_anchors, bbox_deltas, variances) + # clip proposals to image (may result in proposals with zero area + # that will be removed in the next step) + proposals = clip_tiled_boxes(proposals, im_info[:2]) + # remove predicted boxes with height or width < min_size + keep = filter_boxes(proposals, min_size, im_info) + proposals = proposals[keep, :] + scores = scores[keep, :] + + # apply loose nms (e.g. threshold = 0.7) + # take post_nms_topN (e.g. 1000) + # return the top proposals + if nms_thresh > 0: + keep = nms(boxes=proposals, + scores=scores, + nms_threshold=nms_thresh, + eta=eta) + if post_nms_topN > 0 and post_nms_topN < len(keep): + keep = keep[:post_nms_topN] + proposals = proposals[keep, :] + scores = scores[keep, :] + + return proposals, scores + + +def box_coder(all_anchors, bbox_deltas, variances): + """ + Decode proposals by anchors and bbox_deltas from RPN + """ + #proposals: xmin, ymin, xmax, ymax + proposals = np.zeros_like(bbox_deltas, dtype=np.float32) + + #anchor_loc: width, height, center_x, center_y + anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32) + + anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0] + anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1] + anchor_loc[:, 2] = (all_anchors[:, 2] + all_anchors[:, 0]) / 2 + anchor_loc[:, 3] = (all_anchors[:, 3] + all_anchors[:, 1]) / 2 + + #predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height + pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32) + if variances is not None: + for i in range(bbox_deltas.shape[0]): + pred_bbox[i, 0] = variances[i, 0] * bbox_deltas[i, 0] * anchor_loc[ + i, 0] + anchor_loc[i, 2] + pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[ + i, 1] + anchor_loc[i, 3] + pred_bbox[i, 2] = math.exp(variances[i, 2] * + bbox_deltas[i, 2]) * anchor_loc[i, 0] + pred_bbox[i, 3] = math.exp(variances[i, 3] * + bbox_deltas[i, 3]) * anchor_loc[i, 1] + else: + for i in range(bbox_deltas.shape[0]): + pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[ + i, 2] + pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[ + i, 3] + pred_bbox[i, 2] = math.exp(bbox_deltas[i, 2]) * anchor_loc[i, 0] + pred_bbox[i, 3] = math.exp(bbox_deltas[i, 3]) * anchor_loc[i, 1] + + proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2 + proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2 + proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2 + proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2 + + return proposals + + +def clip_tiled_boxes(boxes, im_shape): + """Clip boxes to image boundaries. 
im_shape is [height, width] and boxes + has shape (N, 4 * num_tiled_boxes).""" + assert boxes.shape[1] % 4 == 0, \ + 'boxes.shape[1] is {:d}, but must be divisible by 4.'.format( + boxes.shape[1] + ) + # x1 >= 0 + boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0) + # y1 >= 0 + boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0) + # x2 < im_shape[1] + boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0) + # y2 < im_shape[0] + boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0) + return boxes + + +def filter_boxes(boxes, min_size, im_info): + """Only keep boxes with both sides >= min_size and center within the image. + """ + # Scale min_size to match image scale + min_size *= im_info[2] + ws = boxes[:, 2] - boxes[:, 0] + 1 + hs = boxes[:, 3] - boxes[:, 1] + 1 + x_ctr = boxes[:, 0] + ws / 2. + y_ctr = boxes[:, 1] + hs / 2. + keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_info[1]) & + (y_ctr < im_info[0]))[0] + return keep + + +def iou(box_a, box_b): + """ + Apply intersection-over-union overlap between box_a and box_b + """ + xmin_a = min(box_a[0], box_a[2]) + ymin_a = min(box_a[1], box_a[3]) + xmax_a = max(box_a[0], box_a[2]) + ymax_a = max(box_a[1], box_a[3]) + + xmin_b = min(box_b[0], box_b[2]) + ymin_b = min(box_b[1], box_b[3]) + xmax_b = max(box_b[0], box_b[2]) + ymax_b = max(box_b[1], box_b[3]) + + area_a = (ymax_a - ymin_a + 1) * (xmax_a - xmin_a + 1) + area_b = (ymax_b - ymin_b + 1) * (xmax_b - xmin_b + 1) + if area_a <= 0 and area_b <= 0: + return 0.0 + + xa = max(xmin_a, xmin_b) + ya = max(ymin_a, ymin_b) + xb = min(xmax_a, xmax_b) + yb = min(ymax_a, ymax_b) + + inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0) + + iou_ratio = inter_area / (area_a + area_b - inter_area) + + return iou_ratio + + +def nms(boxes, scores, nms_threshold, eta=1.0): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + nms_threshold: (float) The overlap thresh for suppressing unnecessary + boxes. + eta: (float) The parameter for adaptive NMS. + Return: + The indices of the kept boxes with respect to num_priors. 
+ """ + all_scores = copy.deepcopy(scores) + all_scores = all_scores.flatten() + + sorted_indices = np.argsort(-all_scores, axis=0, kind='mergesort') + sorted_scores = all_scores[sorted_indices] + selected_indices = [] + adaptive_threshold = nms_threshold + for i in range(sorted_scores.shape[0]): + idx = sorted_indices[i] + keep = True + for k in range(len(selected_indices)): + if keep: + kept_idx = selected_indices[k] + overlap = iou(boxes[idx], boxes[kept_idx]) + keep = True if overlap <= adaptive_threshold else False + else: + break + if keep: + selected_indices.append(idx) + if keep and eta < 1 and adaptive_threshold > 0.5: + adaptive_threshold *= eta + return selected_indices + + +class TestGenerateProposalsOp(OpTest): + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = { + 'Scores': self.scores, + 'BboxDeltas': self.bbox_deltas, + 'ImInfo': self.im_info.astype(np.float32), + 'Anchors': self.anchors, + 'Variances': self.variances + } + + self.attrs = { + 'pre_nms_topN': self.pre_nms_topN, + 'post_nms_topN': self.post_nms_topN, + 'nms_thresh': self.nms_thresh, + 'min_size': self.min_size, + 'eta': self.eta + } + + print("lod = ", self.lod) + self.outputs = { + 'RpnRois': (self.rpn_rois[0], [self.lod]), + 'RpnRoiProbs': (self.rpn_roi_probs[0], [self.lod]) + } + + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "generate_proposals" + self.set_data() + + def init_test_params(self): + self.pre_nms_topN = 12000 # train 12000, test 2000 + self.post_nms_topN = 5000 # train 6000, test 1000 + self.nms_thresh = 0.7 + self.min_size = 3.0 + self.eta = 0.8 + + def init_test_input(self): + batch_size = 1 + input_channels = 20 + layer_h = 16 + layer_w = 16 + input_feat = np.random.random( + (batch_size, input_channels, layer_h, layer_w)).astype('float32') + self.anchors, self.variances = anchor_generator_in_python( + input_feat=input_feat, + anchor_sizes=[16., 32.], + aspect_ratios=[0.5, 1.0], + variances=[1.0, 1.0, 1.0, 1.0], + stride=[16.0, 16.0], + offset=0.5) + self.im_info = np.array([[64., 64., 8.]]) #im_height, im_width, scale + num_anchors = self.anchors.shape[2] + self.scores = np.random.random( + (batch_size, num_anchors, layer_h, layer_w)).astype('float32') + self.bbox_deltas = np.random.random( + (batch_size, num_anchors * 4, layer_h, layer_w)).astype('float32') + + def init_test_output(self): + self.rpn_rois, self.rpn_roi_probs, self.lod = generate_proposals_in_python( + self.scores, self.bbox_deltas, self.im_info, self.anchors, + self.variances, self.pre_nms_topN, self.post_nms_topN, + self.nms_thresh, self.min_size, self.eta) + + +if __name__ == '__main__': + unittest.main()
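For reference (not part of the patch), below is a standalone NumPy sketch of the box decoding that BoxCoder (C++) and box_coder (Python reference) implement above, applied to a single anchor. All names and values here are illustrative only.

# Standalone sketch of the anchor decoding used by generate_proposals.
# anchor/delta/variance values are made up for illustration.
import numpy as np

anchor = np.array([0., 0., 16., 16.])      # xmin, ymin, xmax, ymax
delta = np.array([0.1, 0.2, 0.0, 0.0])     # dx, dy, dw, dh predicted by the RPN
variance = np.array([0.1, 0.1, 0.2, 0.2])  # per-coordinate variance

# anchor width/height and center
aw, ah = anchor[2] - anchor[0], anchor[3] - anchor[1]
acx, acy = (anchor[2] + anchor[0]) / 2, (anchor[3] + anchor[1]) / 2

# decoded center and size: shift the center by variance*delta*size,
# scale the size by exp(variance*delta)
cx = variance[0] * delta[0] * aw + acx
cy = variance[1] * delta[1] * ah + acy
w = np.exp(variance[2] * delta[2]) * aw
h = np.exp(variance[3] * delta[3]) * ah

# back to corner format (xmin, ymin, xmax, ymax)
proposal = np.array([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2])
print(proposal)  # approximately [0.16, 0.32, 16.16, 16.32]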