/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; static const double kBBoxClipDefault = std::log(1000.0 / 16.0); static void AppendProposals(Tensor *dst, int64_t offset, const Tensor &src) { auto *out_data = dst->data(); auto *to_add_data = src.data(); size_t size_of_t = framework::SizeOfType(src.type()); offset *= size_of_t; std::memcpy( reinterpret_cast(reinterpret_cast(out_data) + offset), to_add_data, src.numel() * size_of_t); } class GenerateProposalsOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Scores"), "Input(Scores) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("BboxDeltas"), "Input(BboxDeltas) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("ImInfo"), "Input(ImInfo) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("Anchors"), "Input(Anchors) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("Variances"), "Input(Variances) shouldn't be null."); ctx->SetOutputDim("RpnRois", {-1, 4}); ctx->SetOutputDim("RpnRoiProbs", {-1, 1}); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"), ctx.device_context()); } }; template static inline void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors, Tensor *bbox_deltas, Tensor *variances, Tensor *proposals) { T *proposals_data = proposals->mutable_data(ctx.GetPlace()); int64_t row = all_anchors->dims()[0]; int64_t len = all_anchors->dims()[1]; auto *bbox_deltas_data = bbox_deltas->data(); auto *anchor_data = all_anchors->data(); const T *variances_data = nullptr; if (variances) { variances_data = variances->data(); } for (int64_t i = 0; i < row; ++i) { T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0; T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1] + 1.0; T anchor_center_x = anchor_data[i * len] + 0.5 * anchor_width; T anchor_center_y = anchor_data[i * len + 1] + 0.5 * anchor_height; T bbox_center_x = 0, bbox_center_y = 0; T bbox_width = 0, bbox_height = 0; if (variances) { bbox_center_x = variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width + anchor_center_x; bbox_center_y = variances_data[i * len + 1] * bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y; bbox_width = std::exp(std::min(variances_data[i * len + 2] * bbox_deltas_data[i * len + 2], kBBoxClipDefault)) * anchor_width; bbox_height = std::exp(std::min(variances_data[i * len + 3] * bbox_deltas_data[i * len + 3], kBBoxClipDefault)) * anchor_height; } else { bbox_center_x = bbox_deltas_data[i * len] * anchor_width + anchor_center_x; bbox_center_y = bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y; bbox_width = std::exp(std::min(bbox_deltas_data[i * len + 2], kBBoxClipDefault)) * anchor_width; bbox_height = std::exp(std::min(bbox_deltas_data[i * len + 3], kBBoxClipDefault)) * anchor_height; } proposals_data[i * len] = bbox_center_x - bbox_width / 2; proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2; proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1; proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1; } // return proposals; } template static inline void ClipTiledBoxes(const platform::DeviceContext &ctx, const Tensor &im_info, Tensor *boxes) { T *boxes_data = boxes->mutable_data(ctx.GetPlace()); const T *im_info_data = im_info.data(); T zero(0); for (int64_t i = 0; i < boxes->numel(); ++i) { if (i % 4 == 0) { boxes_data[i] = std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero); } else if (i % 4 == 1) { boxes_data[i] = std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero); } else if (i % 4 == 2) { boxes_data[i] = std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero); } else { boxes_data[i] = std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero); } } } template static inline void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes, float min_size, const Tensor &im_info, Tensor *keep) { const T *im_info_data = im_info.data(); T *boxes_data = boxes->mutable_data(ctx.GetPlace()); T im_scale = im_info_data[2]; keep->Resize({boxes->dims()[0]}); min_size = std::max(min_size, 1.0f); int *keep_data = keep->mutable_data(ctx.GetPlace()); int keep_len = 0; for (int i = 0; i < boxes->dims()[0]; ++i) { T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1; T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1; T ws_origin_scale = (boxes_data[4 * i + 2] - boxes_data[4 * i]) / im_scale + 1; T hs_origin_scale = (boxes_data[4 * i + 3] - boxes_data[4 * i + 1]) / im_scale + 1; T x_ctr = boxes_data[4 * i] + ws / 2; T y_ctr = boxes_data[4 * i + 1] + hs / 2; if (ws_origin_scale >= min_size && hs_origin_scale >= min_size && x_ctr <= im_info_data[1] && y_ctr <= im_info_data[0]) { keep_data[keep_len++] = i; } } keep->Resize({keep_len}); } template static inline std::vector> GetSortedScoreIndex( const std::vector &scores) { std::vector> sorted_indices; sorted_indices.reserve(scores.size()); for (size_t i = 0; i < scores.size(); ++i) { sorted_indices.emplace_back(scores[i], i); } // Sort the score pair according to the scores in descending order std::stable_sort(sorted_indices.begin(), sorted_indices.end(), [](const std::pair &a, const std::pair &b) { return a.first < b.first; }); return sorted_indices; } template static inline T BBoxArea(const T *box, bool normalized) { if (box[2] < box[0] || box[3] < box[1]) { // If coordinate values are is invalid // (e.g. xmax < xmin or ymax < ymin), return 0. return static_cast(0.); } else { const T w = box[2] - box[0]; const T h = box[3] - box[1]; if (normalized) { return w * h; } else { // If coordinate values are not within range [0, 1]. return (w + 1) * (h + 1); } } } template static inline T JaccardOverlap(const T *box1, const T *box2, bool normalized) { if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || box2[3] < box1[1]) { return static_cast(0.); } else { const T inter_xmin = std::max(box1[0], box2[0]); const T inter_ymin = std::max(box1[1], box2[1]); const T inter_xmax = std::min(box1[2], box2[2]); const T inter_ymax = std::min(box1[3], box2[3]); const T inter_w = std::max(T(0), inter_xmax - inter_xmin + 1); const T inter_h = std::max(T(0), inter_ymax - inter_ymin + 1); const T inter_area = inter_w * inter_h; const T bbox1_area = BBoxArea(box1, normalized); const T bbox2_area = BBoxArea(box2, normalized); return inter_area / (bbox1_area + bbox2_area - inter_area); } } template static inline Tensor VectorToTensor(const std::vector &selected_indices, int selected_num) { Tensor keep_nms; keep_nms.Resize({selected_num}); auto *keep_data = keep_nms.mutable_data(platform::CPUPlace()); for (int i = 0; i < selected_num; ++i) { keep_data[i] = selected_indices[i]; } return keep_nms; } template static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores, T nms_threshold, float eta) { PADDLE_ENFORCE_NOT_NULL(bbox); int64_t num_boxes = bbox->dims()[0]; // 4: [xmin ymin xmax ymax] int64_t box_size = bbox->dims()[1]; std::vector scores_data(num_boxes); std::copy_n(scores->data(), num_boxes, scores_data.begin()); std::vector> sorted_indices = GetSortedScoreIndex(scores_data); std::vector selected_indices; int selected_num = 0; T adaptive_threshold = nms_threshold; const T *bbox_data = bbox->data(); while (sorted_indices.size() != 0) { int idx = sorted_indices.back().second; bool flag = true; for (int kept_idx : selected_indices) { if (flag) { T overlap = JaccardOverlap(bbox_data + idx * box_size, bbox_data + kept_idx * box_size, false); flag = (overlap <= adaptive_threshold); } else { break; } } if (flag) { selected_indices.push_back(idx); ++selected_num; } sorted_indices.erase(sorted_indices.end() - 1); if (flag && eta < 1 && adaptive_threshold > 0.5) { adaptive_threshold *= eta; } } return VectorToTensor(selected_indices, selected_num); } template class GenerateProposalsKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto *scores = context.Input("Scores"); auto *bbox_deltas = context.Input("BboxDeltas"); auto *im_info = context.Input("ImInfo"); auto anchors = GET_DATA_SAFELY(context.Input("Anchors"), "Input", "Anchors", "GenerateProposals"); auto variances = GET_DATA_SAFELY(context.Input("Variances"), "Input", "Variances", "GenerateProposals"); auto *rpn_rois = context.Output("RpnRois"); auto *rpn_roi_probs = context.Output("RpnRoiProbs"); int pre_nms_top_n = context.Attr("pre_nms_topN"); int post_nms_top_n = context.Attr("post_nms_topN"); float nms_thresh = context.Attr("nms_thresh"); float min_size = context.Attr("min_size"); float eta = context.Attr("eta"); auto &dev_ctx = context.template device_context(); auto &scores_dim = scores->dims(); int64_t num = scores_dim[0]; int64_t c_score = scores_dim[1]; int64_t h_score = scores_dim[2]; int64_t w_score = scores_dim[3]; auto &bbox_dim = bbox_deltas->dims(); int64_t c_bbox = bbox_dim[1]; int64_t h_bbox = bbox_dim[2]; int64_t w_bbox = bbox_dim[3]; rpn_rois->mutable_data({bbox_deltas->numel() / 4, 4}, context.GetPlace()); rpn_roi_probs->mutable_data({scores->numel(), 1}, context.GetPlace()); Tensor bbox_deltas_swap, scores_swap; bbox_deltas_swap.mutable_data({num, h_bbox, w_bbox, c_bbox}, dev_ctx.GetPlace()); scores_swap.mutable_data({num, h_score, w_score, c_score}, dev_ctx.GetPlace()); math::Transpose trans; std::vector axis = {0, 2, 3, 1}; trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis); trans(dev_ctx, *scores, &scores_swap, axis); framework::LoD lod; lod.resize(1); auto &lod0 = lod[0]; lod0.push_back(0); anchors.Resize({anchors.numel() / 4, 4}); variances.Resize({variances.numel() / 4, 4}); int64_t num_proposals = 0; for (int64_t i = 0; i < num; ++i) { Tensor im_info_slice = im_info->Slice(i, i + 1); Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1); Tensor scores_slice = scores_swap.Slice(i, i + 1); bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4}); scores_slice.Resize({h_score * w_score * c_score, 1}); std::pair tensor_pair = ProposalForOneImage(dev_ctx, im_info_slice, anchors, variances, bbox_deltas_slice, scores_slice, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, eta); Tensor &proposals = tensor_pair.first; Tensor &scores = tensor_pair.second; AppendProposals(rpn_rois, 4 * num_proposals, proposals); AppendProposals(rpn_roi_probs, num_proposals, scores); num_proposals += proposals.dims()[0]; lod0.push_back(num_proposals); } rpn_rois->set_lod(lod); rpn_roi_probs->set_lod(lod); rpn_rois->Resize({num_proposals, 4}); rpn_roi_probs->Resize({num_proposals, 1}); } std::pair ProposalForOneImage( const platform::CPUDeviceContext &ctx, const Tensor &im_info_slice, const Tensor &anchors, const Tensor &variances, const Tensor &bbox_deltas_slice, // [M, 4] const Tensor &scores_slice, // [N, 1] int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size, float eta) const { auto *scores_data = scores_slice.data(); // Sort index Tensor index_t; index_t.Resize({scores_slice.numel()}); int *index = index_t.mutable_data(ctx.GetPlace()); for (int i = 0; i < scores_slice.numel(); ++i) { index[i] = i; } auto compare = [scores_data](const int64_t &i, const int64_t &j) { return scores_data[i] > scores_data[j]; }; if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) { std::sort(index, index + scores_slice.numel(), compare); } else { std::nth_element(index, index + pre_nms_top_n, index + scores_slice.numel(), compare); index_t.Resize({pre_nms_top_n}); } Tensor scores_sel, bbox_sel, anchor_sel, var_sel; scores_sel.mutable_data({index_t.numel(), 1}, ctx.GetPlace()); bbox_sel.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); anchor_sel.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); var_sel.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); CPUGather(ctx, scores_slice, index_t, &scores_sel); CPUGather(ctx, bbox_deltas_slice, index_t, &bbox_sel); CPUGather(ctx, anchors, index_t, &anchor_sel); CPUGather(ctx, variances, index_t, &var_sel); Tensor proposals; proposals.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); BoxCoder(ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals); ClipTiledBoxes(ctx, im_info_slice, &proposals); Tensor keep; FilterBoxes(ctx, &proposals, min_size, im_info_slice, &keep); Tensor scores_filter; bbox_sel.mutable_data({keep.numel(), 4}, ctx.GetPlace()); scores_filter.mutable_data({keep.numel(), 1}, ctx.GetPlace()); CPUGather(ctx, proposals, keep, &bbox_sel); CPUGather(ctx, scores_sel, keep, &scores_filter); if (nms_thresh <= 0) { return std::make_pair(bbox_sel, scores_filter); } Tensor keep_nms = NMS(ctx, &bbox_sel, &scores_filter, nms_thresh, eta); if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) { keep_nms.Resize({post_nms_top_n}); } proposals.mutable_data({keep_nms.numel(), 4}, ctx.GetPlace()); scores_sel.mutable_data({keep_nms.numel(), 1}, ctx.GetPlace()); CPUGather(ctx, bbox_sel, keep_nms, &proposals); CPUGather(ctx, scores_filter, keep_nms, &scores_sel); return std::make_pair(proposals, scores_sel); } }; class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("Scores", "(Tensor) The scores from conv is in shape (N, A, H, W), " "N is batch size, A is number of anchors, " "H and W are height and width of the feature map"); AddInput("BboxDeltas", "(Tensor) Bounding box deltas from conv is in " "shape (N, 4*A, H, W)."); AddInput("ImInfo", "(Tensor) Information for image reshape is in shape (N, 3), " "in format (height, width, scale)"); AddInput("Anchors", "(Tensor) Bounding box anchors from anchor_generator_op " "is in shape (A, H, W, 4)."); AddInput("Variances", "(Tensor) Bounding box variances with same shape as `Anchors`."); AddOutput("RpnRois", "(LoDTensor), Output proposals with shape (rois_num, 4)."); AddOutput("RpnRoiProbs", "(LoDTensor) Scores of proposals with shape (rois_num, 1)."); AddAttr("pre_nms_topN", "Number of top scoring RPN proposals to keep before " "applying NMS."); AddAttr("post_nms_topN", "Number of top scoring RPN proposals to keep after " "applying NMS"); AddAttr("nms_thresh", "NMS threshold used on RPN proposals."); AddAttr("min_size", "Proposal height and width both need to be greater " "than this min_size."); AddAttr("eta", "The parameter for adaptive NMS."); AddComment(R"DOC( This operator Generate bounding box proposals for Faster RCNN. The propoasls are generated for a list of images based on image score 'Scores', bounding box regression result 'BboxDeltas' as well as predefined bounding box shapes 'anchors'. Greedy non-maximum suppression is applied to generate the final bounding boxes. )DOC"); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR( generate_proposals, ops::GenerateProposalsOp, ops::GenerateProposalsOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL(generate_proposals, ops::GenerateProposalsKernel, ops::GenerateProposalsKernel);