/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/nms_util.h" #include "paddle/phi/kernels/funcs/gather.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; class GenerateProposalsV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE_EQ( ctx->HasInput("Scores"), true, platform::errors::NotFound("Input(Scores) shouldn't be null.")); PADDLE_ENFORCE_EQ( ctx->HasInput("BboxDeltas"), true, platform::errors::NotFound("Input(BboxDeltas) shouldn't be null.")); PADDLE_ENFORCE_EQ( ctx->HasInput("ImShape"), true, platform::errors::NotFound("Input(ImShape) shouldn't be null.")); PADDLE_ENFORCE_EQ( ctx->HasInput("Anchors"), true, platform::errors::NotFound("Input(Anchors) shouldn't be null.")); PADDLE_ENFORCE_EQ( ctx->HasInput("Variances"), true, platform::errors::NotFound("Input(Variances) shouldn't be null.")); ctx->SetOutputDim("RpnRois", {-1, 4}); ctx->SetOutputDim("RpnRoiProbs", {-1, 1}); if (!ctx->IsRuntime()) { ctx->SetLoDLevel("RpnRois", std::max(ctx->GetLoDLevel("Scores"), 1)); ctx->SetLoDLevel("RpnRoiProbs", std::max(ctx->GetLoDLevel("Scores"), 1)); } } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"), ctx.device_context()); } }; template class GenerateProposalsV2Kernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto *scores = context.Input("Scores"); auto *bbox_deltas = context.Input("BboxDeltas"); auto *im_shape = context.Input("ImShape"); auto anchors = GET_DATA_SAFELY(context.Input("Anchors"), "Input", "Anchors", "GenerateProposals"); auto variances = GET_DATA_SAFELY(context.Input("Variances"), "Input", "Variances", "GenerateProposals"); auto *rpn_rois = context.Output("RpnRois"); auto *rpn_roi_probs = context.Output("RpnRoiProbs"); int pre_nms_top_n = context.Attr("pre_nms_topN"); int post_nms_top_n = context.Attr("post_nms_topN"); float nms_thresh = context.Attr("nms_thresh"); float min_size = context.Attr("min_size"); float eta = context.Attr("eta"); bool pixel_offset = context.Attr("pixel_offset"); auto &dev_ctx = context.template device_context(); auto &scores_dim = scores->dims(); int64_t num = scores_dim[0]; int64_t c_score = scores_dim[1]; int64_t h_score = scores_dim[2]; int64_t w_score = scores_dim[3]; auto &bbox_dim = bbox_deltas->dims(); int64_t c_bbox = bbox_dim[1]; int64_t h_bbox = bbox_dim[2]; int64_t w_bbox = bbox_dim[3]; rpn_rois->mutable_data({bbox_deltas->numel() / 4, 4}, context.GetPlace()); rpn_roi_probs->mutable_data({scores->numel(), 1}, context.GetPlace()); Tensor bbox_deltas_swap, scores_swap; bbox_deltas_swap.mutable_data({num, h_bbox, w_bbox, c_bbox}, dev_ctx.GetPlace()); scores_swap.mutable_data({num, h_score, w_score, c_score}, dev_ctx.GetPlace()); phi::funcs::Transpose trans; std::vector axis = {0, 2, 3, 1}; trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis); trans(dev_ctx, *scores, &scores_swap, axis); framework::LoD lod; lod.resize(1); auto &lod0 = lod[0]; lod0.push_back(0); anchors.Resize({anchors.numel() / 4, 4}); variances.Resize({variances.numel() / 4, 4}); std::vector tmp_num; int64_t num_proposals = 0; for (int64_t i = 0; i < num; ++i) { Tensor im_shape_slice = im_shape->Slice(i, i + 1); Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1); Tensor scores_slice = scores_swap.Slice(i, i + 1); bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4}); scores_slice.Resize({h_score * w_score * c_score, 1}); std::pair tensor_pair = ProposalForOneImage( dev_ctx, im_shape_slice, anchors, variances, bbox_deltas_slice, scores_slice, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, eta, pixel_offset); Tensor &proposals = tensor_pair.first; Tensor &scores = tensor_pair.second; AppendProposals(rpn_rois, 4 * num_proposals, proposals); AppendProposals(rpn_roi_probs, num_proposals, scores); num_proposals += proposals.dims()[0]; lod0.push_back(num_proposals); tmp_num.push_back(proposals.dims()[0]); } if (context.HasOutput("RpnRoisNum")) { auto *rpn_rois_num = context.Output("RpnRoisNum"); rpn_rois_num->mutable_data({num}, context.GetPlace()); int *num_data = rpn_rois_num->data(); for (int i = 0; i < num; i++) { num_data[i] = tmp_num[i]; } rpn_rois_num->Resize({num}); } rpn_rois->set_lod(lod); rpn_roi_probs->set_lod(lod); rpn_rois->Resize({num_proposals, 4}); rpn_roi_probs->Resize({num_proposals, 1}); } std::pair ProposalForOneImage( const platform::CPUDeviceContext &ctx, const Tensor &im_shape_slice, const Tensor &anchors, const Tensor &variances, const Tensor &bbox_deltas_slice, // [M, 4] const Tensor &scores_slice, // [N, 1] int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size, float eta, bool pixel_offset = true) const { auto *scores_data = scores_slice.data(); // Sort index Tensor index_t; index_t.Resize({scores_slice.numel()}); int *index = index_t.mutable_data(ctx.GetPlace()); for (int i = 0; i < scores_slice.numel(); ++i) { index[i] = i; } auto compare = [scores_data](const int64_t &i, const int64_t &j) { return scores_data[i] > scores_data[j]; }; if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) { std::sort(index, index + scores_slice.numel(), compare); } else { std::nth_element(index, index + pre_nms_top_n, index + scores_slice.numel(), compare); index_t.Resize({pre_nms_top_n}); } Tensor scores_sel, bbox_sel, anchor_sel, var_sel; scores_sel.mutable_data({index_t.numel(), 1}, ctx.GetPlace()); bbox_sel.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); anchor_sel.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); var_sel.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); phi::funcs::CPUGather(ctx, scores_slice, index_t, &scores_sel); phi::funcs::CPUGather(ctx, bbox_deltas_slice, index_t, &bbox_sel); phi::funcs::CPUGather(ctx, anchors, index_t, &anchor_sel); phi::funcs::CPUGather(ctx, variances, index_t, &var_sel); Tensor proposals; proposals.mutable_data({index_t.numel(), 4}, ctx.GetPlace()); BoxCoder(ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals, pixel_offset); ClipTiledBoxes(ctx, im_shape_slice, proposals, &proposals, false, pixel_offset); Tensor keep; FilterBoxes(ctx, &proposals, min_size, im_shape_slice, false, &keep, pixel_offset); // Handle the case when there is no keep index left if (keep.numel() == 0) { phi::funcs::SetConstant set_zero; bbox_sel.mutable_data({1, 4}, ctx.GetPlace()); set_zero(ctx, &bbox_sel, static_cast(0)); Tensor scores_filter; scores_filter.mutable_data({1, 1}, ctx.GetPlace()); set_zero(ctx, &scores_filter, static_cast(0)); return std::make_pair(bbox_sel, scores_filter); } Tensor scores_filter; bbox_sel.mutable_data({keep.numel(), 4}, ctx.GetPlace()); scores_filter.mutable_data({keep.numel(), 1}, ctx.GetPlace()); phi::funcs::CPUGather(ctx, proposals, keep, &bbox_sel); phi::funcs::CPUGather(ctx, scores_sel, keep, &scores_filter); if (nms_thresh <= 0) { return std::make_pair(bbox_sel, scores_filter); } Tensor keep_nms = NMS(ctx, &bbox_sel, &scores_filter, nms_thresh, eta, pixel_offset); if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) { keep_nms.Resize({post_nms_top_n}); } proposals.mutable_data({keep_nms.numel(), 4}, ctx.GetPlace()); scores_sel.mutable_data({keep_nms.numel(), 1}, ctx.GetPlace()); phi::funcs::CPUGather(ctx, bbox_sel, keep_nms, &proposals); phi::funcs::CPUGather(ctx, scores_filter, keep_nms, &scores_sel); return std::make_pair(proposals, scores_sel); } }; class GenerateProposalsV2OpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("Scores", "(Tensor) The scores from conv is in shape (N, A, H, W), " "N is batch size, A is number of anchors, " "H and W are height and width of the feature map"); AddInput("BboxDeltas", "(Tensor) Bounding box deltas from conv is in " "shape (N, 4*A, H, W)."); AddInput("ImShape", "(Tensor) Image shape in shape (N, 2), " "in format (height, width)"); AddInput("Anchors", "(Tensor) Bounding box anchors from anchor_generator_op " "is in shape (A, H, W, 4)."); AddInput("Variances", "(Tensor) Bounding box variances with same shape as `Anchors`."); AddOutput("RpnRois", "(LoDTensor), Output proposals with shape (rois_num, 4)."); AddOutput("RpnRoiProbs", "(LoDTensor) Scores of proposals with shape (rois_num, 1)."); AddOutput("RpnRoisNum", "(Tensor), The number of Rpn RoIs in each image") .AsDispensable(); AddAttr("pre_nms_topN", "Number of top scoring RPN proposals to keep before " "applying NMS."); AddAttr("post_nms_topN", "Number of top scoring RPN proposals to keep after " "applying NMS"); AddAttr("nms_thresh", "NMS threshold used on RPN proposals."); AddAttr("min_size", "Proposal height and width both need to be greater " "than this min_size."); AddAttr("eta", "The parameter for adaptive NMS."); AddAttr("pixel_offset", "(bool, default True),", "If true, im_shape pixel offset is 1.") .SetDefault(true); AddComment(R"DOC( This operator is the second version of generate_proposals op to generate bounding box proposals for Faster RCNN. The proposals are generated for a list of images based on image score 'Scores', bounding box regression result 'BboxDeltas' as well as predefined bounding box shapes 'anchors'. Greedy non-maximum suppression is applied to generate the final bounding boxes. The difference between this version and the first version is that the image scale is no long needed now, so the input requires im_shape instead of im_info. The change aims to unify the input for all kinds of objective detection such as YOLO-v3 and Faster R-CNN. As a result, the min_size represents the size on input image instead of original image which is slightly different to before and will not effect the result. )DOC"); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR( generate_proposals_v2, ops::GenerateProposalsV2Op, ops::GenerateProposalsV2OpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL(generate_proposals_v2, ops::GenerateProposalsV2Kernel, ops::GenerateProposalsV2Kernel); REGISTER_OP_VERSION(generate_proposals_v2) .AddCheckpoint( R"ROC(Registe generate_proposals_v2 for adding the attribute of pixel_offset)ROC", paddle::framework::compatible::OpVersionDesc().NewAttr( "pixel_offset", "If true, im_shape pixel offset is 1.", true));