diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 66e97caa7059144d47a7825aabafc8b231a2145a..29cde357ca4aad538dff41570c0777a54098cf80 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -360,6 +360,7 @@ paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs
 paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
 paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
 paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'dfc953994fd8fef35c49dd9c6eea37a5'))
+paddle.fluid.layers.collect_fpn_proposals (ArgSpec(args=['multi_rois', 'multi_scores', 'min_level', 'max_level', 'post_nms_top_n', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '82ffd896ecc3c005ae1cad40854dcace'))
 paddle.fluid.layers.accuracy (ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)), ('document', '9808534c12c5e739a10f73ebb0b4eafd'))
 paddle.fluid.layers.auc (ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)), ('document', 'e0e95334fce92d16c2d9db6e7caffc47'))
 paddle.fluid.layers.exponential_decay (ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)), ('document', '98a5050bee8522fcea81aa795adaba51'))
diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt
index 94a2016aa53212c3ae5af6d86cccb117855cc3b4..2d655c3e3fcda697069cf83efc0f6ab1e040cfc4 100644
--- a/paddle/fluid/operators/detection/CMakeLists.txt
+++ b/paddle/fluid/operators/detection/CMakeLists.txt
@@ -39,9 +39,11 @@ detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc bo
 if(WITH_GPU)
   detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
   detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS memory cub)
+  detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS memory cub)
 else()
   detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
   detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc)
+  detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc)
 endif()
 
 detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)
diff --git a/paddle/fluid/operators/detection/bbox_util.h b/paddle/fluid/operators/detection/bbox_util.h
index d4cf9a326cc5000e8e75322b59aefc3fb18e86b6..afc39c1db9fba8bf01a78ade83af1037a83d8d9d 100644
--- a/paddle/fluid/operators/detection/bbox_util.h
+++ b/paddle/fluid/operators/detection/bbox_util.h
@@ -22,10 +22,10 @@ namespace paddle {
 namespace operators {
 
 struct RangeInitFunctor {
-  int start_;
-  int delta_;
-  int* out_;
-  HOSTDEVICE void operator()(size_t i) { out_[i] = start_ + i * delta_; }
+  int start;
+  int delta;
+  int* out;
+  HOSTDEVICE void operator()(size_t i) { out[i] = start + i * delta; }
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0603072835e8f146e5bb006d5759220900a29e56
--- /dev/null
+++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc
@@ -0,0 +1,108 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+
+class CollectFpnProposalsOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *context) const override {
+    PADDLE_ENFORCE(context->HasInputs("MultiLevelRois"),
+                   "Inputs(MultiLevelRois) shouldn't be null");
+    PADDLE_ENFORCE(context->HasInputs("MultiLevelScores"),
+                   "Inputs(MultiLevelScores) shouldn't be null");
+    PADDLE_ENFORCE(
+        context->HasOutput("FpnRois"),
+        "Output(FpnRois) of CollectFpnProposalsOp should not be null");
+    auto roi_dims = context->GetInputsDim("MultiLevelRois");
+    auto score_dims = context->GetInputsDim("MultiLevelScores");
+    auto post_nms_topN = context->Attrs().Get<int>("post_nms_topN");
+    for (auto &roi_dim : roi_dims) {
+      PADDLE_ENFORCE_EQ(roi_dim[1], 4,
+                        "Second dimension of Input(MultiLevelRois) must be 4");
+    }
+    for (auto &score_dim : score_dims) {
+      PADDLE_ENFORCE_EQ(
+          score_dim[1], 1,
+          "Second dimension of Input(MultiLevelScores) must be 1");
+    }
+    context->SetOutputDim("FpnRois", {post_nms_topN, 4});
+    if (!context->IsRuntime()) {
+      // The runtime LoD is computed in the kernel.
+      context->ShareLoD("MultiLevelRois", "FpnRois");
+    }
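+    // At runtime, additionally verify that each RoI input carries exactly
+    // the same LoD as its matching score input, since the kernel pairs them
+    // element by element.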
+    if (context->IsRuntime()) {
+      std::vector<framework::InferShapeVarPtr> roi_inputs =
+          context->GetInputVarPtrs("MultiLevelRois");
+      std::vector<framework::InferShapeVarPtr> score_inputs =
+          context->GetInputVarPtrs("MultiLevelScores");
+      for (size_t i = 0; i < roi_inputs.size(); ++i) {
+        framework::Variable *roi_var =
+            boost::get<framework::Variable *>(roi_inputs[i]);
+        framework::Variable *score_var =
+            boost::get<framework::Variable *>(score_inputs[i]);
+        auto &roi_lod = roi_var->Get<LoDTensor>().lod();
+        auto &score_lod = score_var->Get<LoDTensor>().lod();
+        PADDLE_ENFORCE_EQ(roi_lod, score_lod,
+                          "Inputs(MultiLevelRois) and "
+                          "Inputs(MultiLevelScores) should have same lod.");
+      }
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto data_type =
+        framework::GetDataTypeOfVar(ctx.MultiInputVar("MultiLevelRois")[0]);
+    return framework::OpKernelType(data_type, ctx.GetPlace());
+  }
+};
+
+class CollectFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("MultiLevelRois",
+             "(LoDTensor) Multiple roi LoDTensors from each level in shape "
+             "(N, 4), N is the number of RoIs")
+        .AsDuplicable();
+    AddInput("MultiLevelScores",
+             "(LoDTensor) Multiple score LoDTensors from each level in shape"
+             " (N, 1), N is the number of RoIs.")
+        .AsDuplicable();
+    AddOutput("FpnRois", "(LoDTensor) All selected RoIs with highest scores");
+    AddAttr<int>("post_nms_topN",
+                 "Select post_nms_topN RoIs from"
+                 " all images and all fpn layers");
+    AddComment(R"DOC(
+This operator concatenates the proposals of all images and all FPN levels,
+sorts them by objectness confidence, and keeps the post_nms_topN proposals
+with the highest scores. Finally, the kept RoIs are re-sorted by batch index.
+)DOC");
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(collect_fpn_proposals, ops::CollectFpnProposalsOp,
+                  ops::CollectFpnProposalsOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(collect_fpn_proposals,
+                       ops::CollectFpnProposalsOpKernel<float>,
+                       ops::CollectFpnProposalsOpKernel<double>);
diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8fd76e5406b1de79e0cfd738f969fa27c40ced0f
--- /dev/null
+++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
@@ -0,0 +1,211 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <paddle/fluid/memory/allocation/allocator.h>
+#include "cub/cub.cuh"
+#include "paddle/fluid/framework/mixed_vector.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/fluid/operators/detection/bbox_util.h"
+#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
+#include "paddle/fluid/operators/gather.cu.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
+#include "paddle/fluid/operators/strided_memcpy.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+
+static constexpr int kNumCUDAThreads = 64;
+static constexpr int kNumMaximumNumBlocks = 4096;
+
+const int kBBoxSize = 4;
+
+static inline int NumBlocks(const int N) {
+  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
+                  kNumMaximumNumBlocks);
+}
+
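+// Counts how many of the kept RoIs fall into each image of the batch: every
+// thread atomically increments the bucket of its RoI's batch id, which yields
+// a length-based LoD directly on the GPU.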
+static __global__ void GetLengthLoD(const int nthreads, const int* batch_ids,
+                                    int* length_lod) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
+       i += blockDim.x * gridDim.x) {
+    platform::CudaAtomicAdd(length_lod + batch_ids[i], 1);
+  }
+}
+
+template <typename DeviceContext, typename T>
+class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto roi_ins = ctx.MultiInput<LoDTensor>("MultiLevelRois");
+    const auto score_ins = ctx.MultiInput<LoDTensor>("MultiLevelScores");
+    auto fpn_rois = ctx.Output<LoDTensor>("FpnRois");
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+
+    const int post_nms_topN = ctx.Attr<int>("post_nms_topN");
+
+    // concat inputs along axis = 0
+    int roi_offset = 0;
+    int score_offset = 0;
+    int total_roi_num = 0;
+    for (size_t i = 0; i < roi_ins.size(); ++i) {
+      total_roi_num += roi_ins[i]->dims()[0];
+    }
+
+    int real_post_num = std::min(post_nms_topN, total_roi_num);
+    fpn_rois->mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
+    Tensor concat_rois;
+    Tensor concat_scores;
+    T* concat_rois_data = concat_rois.mutable_data<T>(
+        {total_roi_num, kBBoxSize}, dev_ctx.GetPlace());
+    T* concat_scores_data =
+        concat_scores.mutable_data<T>({total_roi_num, 1}, dev_ctx.GetPlace());
+    Tensor roi_batch_id_list;
+    roi_batch_id_list.Resize({total_roi_num});
+    int* roi_batch_id_data =
+        roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
+    int index = 0;
+    int lod_size;
+    auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
+
+    for (size_t i = 0; i < roi_ins.size(); ++i) {
+      auto roi_in = roi_ins[i];
+      auto score_in = score_ins[i];
+      auto roi_lod = roi_in->lod().back();
+      lod_size = roi_lod.size() - 1;
+      for (size_t n = 0; n < static_cast<size_t>(lod_size); ++n) {
+        for (size_t j = roi_lod[n]; j < roi_lod[n + 1]; ++j) {
+          roi_batch_id_data[index++] = n;
+        }
+      }
+
+      memory::Copy(place, concat_rois_data + roi_offset, place,
+                   roi_in->data<T>(), roi_in->numel() * sizeof(T),
+                   dev_ctx.stream());
+      memory::Copy(place, concat_scores_data + score_offset, place,
+                   score_in->data<T>(), score_in->numel() * sizeof(T),
+                   dev_ctx.stream());
+      roi_offset += roi_in->numel();
+      score_offset += score_in->numel();
+    }
+
+    // copy batch id list to GPU
+    Tensor roi_batch_id_list_gpu;
+    framework::TensorCopy(roi_batch_id_list, dev_ctx.GetPlace(),
+                          &roi_batch_id_list_gpu);
+
+    Tensor index_in_t;
+    int* idx_in =
+        index_in_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace());
+    platform::ForRange<platform::CUDADeviceContext> for_range_total(
+        dev_ctx, total_roi_num);
+    for_range_total(RangeInitFunctor{0, 1, idx_in});
+
+    Tensor keys_out_t;
+    T* keys_out =
+        keys_out_t.mutable_data<T>({total_roi_num}, dev_ctx.GetPlace());
+    Tensor index_out_t;
+    int* idx_out =
+        index_out_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace());
+
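+    // cub radix sort follows the usual two-phase calling convention: the
+    // first call, with a null temporary-storage pointer, only computes the
+    // required scratch size; the second call performs the actual descending
+    // key-value sort of the scores and yields the permutation idx_out.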
+    // Determine temporary device storage requirements
+    size_t temp_storage_bytes = 0;
+    cub::DeviceRadixSort::SortPairsDescending(
+        nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out,
+        idx_in, idx_out, total_roi_num);
+    // Allocate temporary storage
+    auto d_temp_storage = memory::Alloc(place, temp_storage_bytes,
+                                        memory::Allocator::kScratchpad);
+
+    // Run sorting operation
+    // sort score to get corresponding index
+    cub::DeviceRadixSort::SortPairsDescending(
+        d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
+        keys_out, idx_in, idx_out, total_roi_num);
+    index_out_t.Resize({real_post_num});
+    Tensor sorted_rois;
+    sorted_rois.mutable_data<T>({real_post_num, kBBoxSize},
+                                dev_ctx.GetPlace());
+    Tensor sorted_batch_id;
+    sorted_batch_id.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
+    GPUGather<T>(dev_ctx, concat_rois, index_out_t, &sorted_rois);
+    GPUGather<int>(dev_ctx, roi_batch_id_list_gpu, index_out_t,
+                   &sorted_batch_id);
+
+    Tensor batch_index_t;
+    int* batch_idx_in =
+        batch_index_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
+    platform::ForRange<platform::CUDADeviceContext> for_range_post(
+        dev_ctx, real_post_num);
+    for_range_post(RangeInitFunctor{0, 1, batch_idx_in});
+
+    Tensor out_id_t;
+    int* out_id_data =
+        out_id_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
+    // Determine temporary device storage requirements
+    temp_storage_bytes = 0;
+    cub::DeviceRadixSort::SortPairs(
+        nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
+        batch_idx_in, index_out_t.data<int>(), real_post_num);
+    // Allocate temporary storage
+    d_temp_storage = memory::Alloc(place, temp_storage_bytes,
+                                   memory::Allocator::kScratchpad);
+
+    // Run sorting operation
+    // sort batch_id to get corresponding index
+    cub::DeviceRadixSort::SortPairs(
+        d_temp_storage->ptr(), temp_storage_bytes,
+        sorted_batch_id.data<int>(), out_id_data, batch_idx_in,
+        index_out_t.data<int>(), real_post_num);
+
+    GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);
+
+    Tensor length_lod;
+    int* length_lod_data =
+        length_lod.mutable_data<int>({lod_size}, dev_ctx.GetPlace());
+    math::SetConstant<platform::CUDADeviceContext, int> set_zero;
+    set_zero(dev_ctx, &length_lod, static_cast<int>(0));
+
+    int blocks = NumBlocks(real_post_num);
+    int threads = kNumCUDAThreads;
+
+    // get length-based lod by batch ids
+    GetLengthLoD<<<blocks, threads, 0, dev_ctx.stream()>>>(
+        real_post_num, out_id_data, length_lod_data);
+    std::vector<int> length_lod_cpu(lod_size);
+    memory::Copy(platform::CPUPlace(), length_lod_cpu.data(), place,
+                 length_lod_data, sizeof(int) * lod_size, dev_ctx.stream());
+    dev_ctx.Wait();
+
+    std::vector<size_t> offset(1, 0);
+    for (int i = 0; i < lod_size; ++i) {
+      offset.emplace_back(offset.back() + length_lod_cpu[i]);
+    }
+
+    framework::LoD lod;
+    lod.emplace_back(offset);
+    fpn_rois->set_lod(lod);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    collect_fpn_proposals,
+    ops::GPUCollectFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
+                                        float>,
+    ops::GPUCollectFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
+                                        double>);
diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.h b/paddle/fluid/operators/detection/collect_fpn_proposals_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..268f7e2160f59c4f1780b1c0968b1e886d27ed1d
--- /dev/null
+++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.h
@@ -0,0 +1,149 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/operators/gather.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+const int kBoxDim = 4;
+
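+// ScoreWithID bundles a score with the bookkeeping needed to recover the
+// originating RoI after sorting: its batch id, its row index within the
+// level, and the FPN level it came from.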
+template <typename T>
+struct ScoreWithID {
+  T score;
+  int batch_id;
+  int index;
+  int level;
+  ScoreWithID() {
+    batch_id = -1;
+    index = -1;
+    level = -1;
+  }
+  ScoreWithID(T score_, int batch_id_, int index_, int level_) {
+    score = score_;
+    batch_id = batch_id_;
+    index = index_;
+    level = level_;
+  }
+};
+
+template <typename T>
+static inline bool CompareByScore(ScoreWithID<T> a, ScoreWithID<T> b) {
+  return a.score > b.score;
+}
+
+template <typename T>
+static inline bool CompareByBatchid(ScoreWithID<T> a, ScoreWithID<T> b) {
+  return a.batch_id < b.batch_id;
+}
+
+template <typename T>
+class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto multi_layer_rois =
+        context.MultiInput<paddle::framework::LoDTensor>("MultiLevelRois");
+
+    auto multi_layer_scores =
+        context.MultiInput<paddle::framework::LoDTensor>("MultiLevelScores");
+
+    auto* fpn_rois = context.Output<paddle::framework::LoDTensor>("FpnRois");
+
+    int post_nms_topN = context.Attr<int>("post_nms_topN");
+
+    PADDLE_ENFORCE_GE(post_nms_topN, 0,
+                      "The parameter post_nms_topN must be non-negative");
+
+    // the number of RoI inputs and score inputs must match level by level
+    PADDLE_ENFORCE(multi_layer_rois.size() == multi_layer_scores.size(),
+                   "Inputs(MultiLevelRois) and Inputs(MultiLevelScores) "
+                   "should have the same number of levels");
+
+    // compute the prefix sum of the RoI counts over all levels
+    const int num_fpn_level = multi_layer_rois.size();
+    std::vector<int> integral_of_all_rois(num_fpn_level + 1, 0);
+    for (int i = 0; i < num_fpn_level; ++i) {
+      auto cur_rois_lod = multi_layer_rois[i]->lod().back();
+      integral_of_all_rois[i + 1] =
+          integral_of_all_rois[i] + cur_rois_lod[cur_rois_lod.size() - 1];
+    }
+
+    // concatenate all fpn rois scores into a list
+    // create a vector to store all scores
+    std::vector<ScoreWithID<T>> scores_of_all_rois(
+        integral_of_all_rois[num_fpn_level], ScoreWithID<T>());
+    for (int i = 0; i < num_fpn_level; ++i) {
+      const T* cur_level_scores = multi_layer_scores[i]->data<T>();
+      int cur_level_num =
+          integral_of_all_rois[i + 1] - integral_of_all_rois[i];
+      auto cur_scores_lod = multi_layer_scores[i]->lod().back();
+      int cur_batch_id = 0;
+      for (int j = 0; j < cur_level_num; ++j) {
+        if (j >= cur_scores_lod[cur_batch_id + 1]) {
+          cur_batch_id++;
+        }
+        int cur_index = j + integral_of_all_rois[i];
+        scores_of_all_rois[cur_index].score = cur_level_scores[j];
+        scores_of_all_rois[cur_index].index = j;
+        scores_of_all_rois[cur_index].level = i;
+        scores_of_all_rois[cur_index].batch_id = cur_batch_id;
+      }
+    }
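+    // std::stable_sort keeps the relative order of elements with equal
+    // scores, so ties stay in their original level/batch order and the
+    // result is deterministic.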
+    // keep top post_nms_topN rois
+    // sort the rois by the score
+    if (post_nms_topN > integral_of_all_rois[num_fpn_level]) {
+      post_nms_topN = integral_of_all_rois[num_fpn_level];
+    }
+    std::stable_sort(scores_of_all_rois.begin(), scores_of_all_rois.end(),
+                     CompareByScore<T>);
+    scores_of_all_rois.resize(post_nms_topN);
+    // sort by batch id
+    std::stable_sort(scores_of_all_rois.begin(), scores_of_all_rois.end(),
+                     CompareByBatchid<T>);
+    // create a pointer array
+    std::vector<const T*> multi_fpn_rois_data(num_fpn_level);
+    for (int i = 0; i < num_fpn_level; ++i) {
+      multi_fpn_rois_data[i] = multi_layer_rois[i]->data<T>();
+    }
+    // initialize the outputs
+    fpn_rois->mutable_data<T>({post_nms_topN, kBoxDim}, context.GetPlace());
+    T* fpn_rois_data = fpn_rois->data<T>();
+    std::vector<size_t> lod0(1, 0);
+    int cur_batch_id = 0;
+    for (int i = 0; i < post_nms_topN; ++i) {
+      int cur_fpn_level = scores_of_all_rois[i].level;
+      int cur_level_index = scores_of_all_rois[i].index;
+      memcpy(fpn_rois_data,
+             multi_fpn_rois_data[cur_fpn_level] + cur_level_index * kBoxDim,
+             kBoxDim * sizeof(T));
+      fpn_rois_data += kBoxDim;
+      if (scores_of_all_rois[i].batch_id != cur_batch_id) {
+        cur_batch_id = scores_of_all_rois[i].batch_id;
+        lod0.emplace_back(i);
+      }
+    }
+    lod0.emplace_back(post_nms_topN);
+    framework::LoD lod;
+    lod.emplace_back(lod0);
+    fpn_rois->set_lod(lod);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 80895e7a5d82ed0d18c81fe21a263b613a97f19c..93afedbdb0f5a42eb7dc4d1d352878ec60a32d17 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -54,6 +54,7 @@ __all__ = [
     'multiclass_nms',
     'distribute_fpn_proposals',
     'box_decoder_and_assign',
+    'collect_fpn_proposals',
 ]
 
 
@@ -2512,3 +2513,68 @@ def box_decoder_and_assign(prior_box,
             "OutputAssignBox": output_assign_box
         })
     return decoded_box, output_assign_box
+
+
+def collect_fpn_proposals(multi_rois,
+                          multi_scores,
+                          min_level,
+                          max_level,
+                          post_nms_top_n,
+                          name=None):
+    """
+    Concat multi-level RoIs (Regions of Interest) and select N RoIs
+    with respect to multi_scores. This operation performs the following steps:
+
+    1. Choose num_level RoIs and scores as input: num_level = max_level - min_level + 1
+    2. Concat multi-level RoIs and scores
+    3. Sort scores and select post_nms_top_n scores
+    4. Gather RoIs by selected indices from scores
+    5. Re-sort RoIs by corresponding batch_id
+
+    Args:
+        multi_rois(list): List of RoIs to collect
+        multi_scores(list): List of scores
+        min_level(int): The lowest level of FPN layer to collect
+        max_level(int): The highest level of FPN layer to collect
+        post_nms_top_n(int): The number of selected RoIs
+        name(str|None): A name for this layer (optional)
+
+    Returns:
+        Variable: Output variable of selected RoIs.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+
+            multi_rois = []
+            multi_scores = []
+            for i in range(4):
+                multi_rois.append(fluid.layers.data(
+                    name='roi_'+str(i), shape=[4], dtype='float32', lod_level=1))
+            for i in range(4):
+                multi_scores.append(fluid.layers.data(
+                    name='score_'+str(i), shape=[1], dtype='float32', lod_level=1))
+
+            fpn_rois = fluid.layers.collect_fpn_proposals(
+                multi_rois=multi_rois,
+                multi_scores=multi_scores,
+                min_level=2,
+                max_level=5,
+                post_nms_top_n=2000)
+    """
+
+    helper = LayerHelper('collect_fpn_proposals', **locals())
+    dtype = helper.input_dtype('multi_rois')
+    num_lvl = max_level - min_level + 1
+    input_rois = multi_rois[:num_lvl]
+    input_scores = multi_scores[:num_lvl]
+    output_rois = helper.create_variable_for_type_inference(dtype)
+    output_rois.stop_gradient = True
+    helper.append_op(
+        type='collect_fpn_proposals',
+        inputs={
+            'MultiLevelRois': input_rois,
+            'MultiLevelScores': input_scores
+        },
+        outputs={'FpnRois': output_rois},
+        attrs={'post_nms_topN': post_nms_top_n})
+    return output_rois
diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py
index e1c4c2eca08d4652ecda8e2579d342818c803f4a..434b69c9680e0b8625f3156b6a0ec338aa211d57 100644
--- a/python/paddle/fluid/tests/test_detection.py
+++ b/python/paddle/fluid/tests/test_detection.py
@@ -522,6 +522,32 @@ class TestMulticlassNMS(unittest.TestCase):
         self.assertIsNotNone(output)
 
 
+class TestCollectFpnProposals(unittest.TestCase):
+    def test_collect_fpn_proposals(self):
+        program = Program()
+        with program_guard(program):
+            multi_bboxes = []
+            multi_scores = []
+            for i in range(4):
+                bboxes = layers.data(
+                    name='rois' + str(i),
+                    shape=[10, 4],
+                    dtype='float32',
+                    lod_level=1,
+                    append_batch_size=False)
+                scores = layers.data(
+                    name='scores' + str(i),
+                    shape=[10, 1],
+                    dtype='float32',
+                    lod_level=1,
+                    append_batch_size=False)
+                multi_bboxes.append(bboxes)
+                multi_scores.append(scores)
+            fpn_rois = layers.collect_fpn_proposals(multi_bboxes, multi_scores,
+                                                    2, 5, 10)
+            self.assertIsNotNone(fpn_rois)
+
+
 class TestDistributeFpnProposals(unittest.TestCase):
     def test_distribute_fpn_proposals(self):
         program = Program()
diff --git a/python/paddle/fluid/tests/unittests/test_collect_fpn_proposals_op.py b/python/paddle/fluid/tests/unittests/test_collect_fpn_proposals_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b8600f004bb2d4b057bd93415ba29b989d858ce
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_collect_fpn_proposals_op.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
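+
+# The NumPy reference implementation below mirrors the operator: concatenate
+# the RoIs of all levels, keep the post_nms_top_n highest-scoring ones, and
+# then regroup them by batch id to rebuild the output LoD.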
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import math
+import sys
+from op_test import OpTest
+
+
+class TestCollectFPNProposalsOp(OpTest):
+    def set_data(self):
+        self.init_test_case()
+        self.make_rois()
+        self.scores_input = [('y%d' % i,
+                              (self.scores[i].reshape(-1, 1),
+                               self.rois_lod[i]))
+                             for i in range(self.num_level)]
+        self.rois, self.lod = self.calc_rois_collect()
+        inputs_x = [('x%d' % i, (self.roi_inputs[i][:, 1:],
+                                 self.rois_lod[i]))
+                    for i in range(self.num_level)]
+        self.inputs = {
+            'MultiLevelRois': inputs_x,
+            "MultiLevelScores": self.scores_input
+        }
+        self.attrs = {'post_nms_topN': self.post_nms_top_n, }
+        self.outputs = {'FpnRois': (self.rois, [self.lod])}
+
+    def init_test_case(self):
+        self.post_nms_top_n = 20
+        self.images_shape = [100, 100]
+
+    def resort_roi_by_batch_id(self, rois):
+        batch_id_list = rois[:, 0]
+        batch_size = int(batch_id_list.max())
+        sorted_rois = []
+        new_lod = []
+        for batch_id in range(batch_size + 1):
+            sub_ind = np.where(batch_id_list == batch_id)[0]
+            sub_rois = rois[sub_ind, 1:]
+            sorted_rois.append(sub_rois)
+            new_lod.append(len(sub_rois))
+        new_rois = np.concatenate(sorted_rois)
+        return new_rois, new_lod
+
+    def calc_rois_collect(self):
+        roi_inputs = np.concatenate(self.roi_inputs)
+        scores = np.concatenate(self.scores)
+        inds = np.argsort(-scores)[:self.post_nms_top_n]
+        rois = roi_inputs[inds, :]
+        new_rois, new_lod = self.resort_roi_by_batch_id(rois)
+        return new_rois, new_lod
+
+    def make_rois(self):
+        self.num_level = 4
+        self.roi_inputs = []
+        self.scores = []
+        self.rois_lod = [[[20, 10]], [[30, 20]], [[20, 30]], [[10, 10]]]
+        for lvl in range(self.num_level):
+            rois = []
+            scores_pb = []
+            lod = self.rois_lod[lvl][0]
+            bno = 0
+            for roi_num in lod:
+                for i in range(roi_num):
+                    xywh = np.random.rand(4)
+                    xy1 = xywh[0:2] * 20
+                    wh = xywh[2:4] * (self.images_shape - xy1)
+                    xy2 = xy1 + wh
+                    roi = [bno, xy1[0], xy1[1], xy2[0], xy2[1]]
+                    rois.append(roi)
+                bno += 1
+                scores_pb.extend(list(np.random.uniform(0.0, 1.0, roi_num)))
+            rois = np.array(rois).astype("float32")
+            self.roi_inputs.append(rois)
+            scores_pb = np.array(scores_pb).astype("float32")
+            self.scores.append(scores_pb)
+
+    def setUp(self):
+        self.op_type = "collect_fpn_proposals"
+        self.set_data()
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == '__main__':
+    unittest.main()