未验证 提交 1c6d0646 编写于 作者: J jerrywgz 提交者: GitHub

add collect fpn proposals op,test=develop (#16074)

* add collect fpn proposals op,test=develop
上级 60be66e2
......@@ -360,6 +360,7 @@ paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'dfc953994fd8fef35c49dd9c6eea37a5'))
paddle.fluid.layers.collect_fpn_proposals (ArgSpec(args=['multi_rois', 'multi_scores', 'min_level', 'max_level', 'post_nms_top_n', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '82ffd896ecc3c005ae1cad40854dcace'))
paddle.fluid.layers.accuracy (ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)), ('document', '9808534c12c5e739a10f73ebb0b4eafd'))
paddle.fluid.layers.auc (ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)), ('document', 'e0e95334fce92d16c2d9db6e7caffc47'))
paddle.fluid.layers.exponential_decay (ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)), ('document', '98a5050bee8522fcea81aa795adaba51'))
......
......@@ -39,9 +39,11 @@ detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc bo
if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS memory cub)
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS memory cub)
else()
detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc)
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc)
endif()
detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)
......
......@@ -22,10 +22,10 @@ namespace paddle {
namespace operators {
struct RangeInitFunctor {
int start_;
int delta_;
int* out_;
HOSTDEVICE void operator()(size_t i) { out_[i] = start_ + i * delta_; }
int start;
int delta;
int* out;
HOSTDEVICE void operator()(size_t i) { out[i] = start + i * delta; }
};
template <typename T>
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class CollectFpnProposalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInputs("MultiLevelRois"),
"Inputs(MultiLevelRois) shouldn't be null");
PADDLE_ENFORCE(context->HasInputs("MultiLevelScores"),
"Inputs(MultiLevelScores) shouldn't be null");
PADDLE_ENFORCE(context->HasOutput("FpnRois"),
"Outputs(MultiFpnRois) of DistributeOp should not be null");
auto roi_dims = context->GetInputsDim("MultiLevelRois");
auto score_dims = context->GetInputsDim("MultiLevelScores");
auto post_nms_topN = context->Attrs().Get<int>("post_nms_topN");
std::vector<int64_t> out_dims;
for (auto &roi_dim : roi_dims) {
PADDLE_ENFORCE_EQ(roi_dim[1], 4,
"Second dimension of Input(MultiLevelRois) must be 4");
}
for (auto &score_dim : score_dims) {
PADDLE_ENFORCE_EQ(
score_dim[1], 1,
"Second dimension of Input(MultiLevelScores) must be 1");
}
context->SetOutputDim("FpnRois", {post_nms_topN, 4});
if (!context->IsRuntime()) { // Runtime LoD infershape will be computed
// in Kernel.
context->ShareLoD("MultiLevelRois", "FpnRois");
}
if (context->IsRuntime()) {
std::vector<framework::InferShapeVarPtr> roi_inputs =
context->GetInputVarPtrs("MultiLevelRois");
std::vector<framework::InferShapeVarPtr> score_inputs =
context->GetInputVarPtrs("MultiLevelScores");
for (size_t i = 0; i < roi_inputs.size(); ++i) {
framework::Variable *roi_var =
boost::get<framework::Variable *>(roi_inputs[i]);
framework::Variable *score_var =
boost::get<framework::Variable *>(score_inputs[i]);
auto &roi_lod = roi_var->Get<LoDTensor>().lod();
auto &score_lod = score_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE_EQ(roi_lod, score_lod,
"Inputs(MultiLevelRois) and Inputs(MultiLevelScores) "
"should have same lod.");
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type =
framework::GetDataTypeOfVar(ctx.MultiInputVar("MultiLevelRois")[0]);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
class CollectFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("MultiLevelRois",
"(LoDTensor) Multiple roi LoDTensors from each level in shape "
"(N, 4), N is the number of RoIs")
.AsDuplicable();
AddInput("MultiLevelScores",
"(LoDTensor) Multiple score LoDTensors from each level in shape"
" (N, 1), N is the number of RoIs.")
.AsDuplicable();
AddOutput("FpnRois", "(LoDTensor) All selected RoIs with highest scores");
AddAttr<int>("post_nms_topN",
"Select post_nms_topN RoIs from"
" all images and all fpn layers");
AddComment(R"DOC(
This operator concats all proposals from different images
and different FPN levels. Then sort all of those proposals
by objectness confidence. Select the post_nms_topN RoIs in
total. Finally, re-sort the RoIs in the order of batch index.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(collect_fpn_proposals, ops::CollectFpnProposalsOp,
ops::CollectFpnProposalsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(collect_fpn_proposals,
ops::CollectFpnProposalsOpKernel<float>,
ops::CollectFpnProposalsOpKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/memory/allocation/allocator.h>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kNumCUDAThreads = 64;
static constexpr int kNumMaxinumNumBlocks = 4096;
const int kBBoxSize = 4;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
static __global__ void GetLengthLoD(const int nthreads, const int* batch_ids,
int* length_lod) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (nthreads);
i += blockDim.x * gridDim.x) {
platform::CudaAtomicAdd(length_lod + batch_ids[i], 1);
}
}
template <typename DeviceContext, typename T>
class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto roi_ins = ctx.MultiInput<LoDTensor>("MultiLevelRois");
const auto score_ins = ctx.MultiInput<LoDTensor>("MultiLevelScores");
auto fpn_rois = ctx.Output<LoDTensor>("FpnRois");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int post_nms_topN = ctx.Attr<int>("post_nms_topN");
// concat inputs along axis = 0
int roi_offset = 0;
int score_offset = 0;
int total_roi_num = 0;
for (size_t i = 0; i < roi_ins.size(); ++i) {
total_roi_num += roi_ins[i]->dims()[0];
}
int real_post_num = min(post_nms_topN, total_roi_num);
fpn_rois->mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
Tensor concat_rois;
Tensor concat_scores;
T* concat_rois_data = concat_rois.mutable_data<T>(
{total_roi_num, kBBoxSize}, dev_ctx.GetPlace());
T* concat_scores_data =
concat_scores.mutable_data<T>({total_roi_num, 1}, dev_ctx.GetPlace());
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({total_roi_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
int index = 0;
int lod_size;
auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
for (size_t i = 0; i < roi_ins.size(); ++i) {
auto roi_in = roi_ins[i];
auto score_in = score_ins[i];
auto roi_lod = roi_in->lod().back();
lod_size = roi_lod.size() - 1;
for (size_t n = 0; n < lod_size; ++n) {
for (size_t j = roi_lod[n]; j < roi_lod[n + 1]; ++j) {
roi_batch_id_data[index++] = n;
}
}
memory::Copy(place, concat_rois_data + roi_offset, place,
roi_in->data<T>(), roi_in->numel() * sizeof(T),
dev_ctx.stream());
memory::Copy(place, concat_scores_data + score_offset, place,
score_in->data<T>(), score_in->numel() * sizeof(T),
dev_ctx.stream());
roi_offset += roi_in->numel();
score_offset += score_in->numel();
}
// copy batch id list to GPU
Tensor roi_batch_id_list_gpu;
framework::TensorCopy(roi_batch_id_list, dev_ctx.GetPlace(),
&roi_batch_id_list_gpu);
Tensor index_in_t;
int* idx_in =
index_in_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace());
platform::ForRange<platform::CUDADeviceContext> for_range_total(
dev_ctx, total_roi_num);
for_range_total(RangeInitFunctor{0, 1, idx_in});
Tensor keys_out_t;
T* keys_out =
keys_out_t.mutable_data<T>({total_roi_num}, dev_ctx.GetPlace());
Tensor index_out_t;
int* idx_out =
index_out_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace());
// Determine temporary device storage requirements
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<T, int>(
nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
idx_out, total_roi_num);
// Allocate temporary storage
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes,
memory::Allocator::kScratchpad);
// Run sorting operation
// sort score to get corresponding index
cub::DeviceRadixSort::SortPairsDescending<T, int>(
d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
keys_out, idx_in, idx_out, total_roi_num);
index_out_t.Resize({real_post_num});
Tensor sorted_rois;
sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
Tensor sorted_batch_id;
sorted_batch_id.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
GPUGather<T>(dev_ctx, concat_rois, index_out_t, &sorted_rois);
GPUGather<int>(dev_ctx, roi_batch_id_list_gpu, index_out_t,
&sorted_batch_id);
Tensor batch_index_t;
int* batch_idx_in =
batch_index_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
platform::ForRange<platform::CUDADeviceContext> for_range_post(
dev_ctx, real_post_num);
for_range_post(RangeInitFunctor{0, 1, batch_idx_in});
Tensor out_id_t;
int* out_id_data =
out_id_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
// Determine temporary device storage requirements
temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs<int, int>(
nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
batch_idx_in, index_out_t.data<int>(), real_post_num);
// Allocate temporary storage
d_temp_storage = memory::Alloc(place, temp_storage_bytes,
memory::Allocator::kScratchpad);
// Run sorting operation
// sort batch_id to get corresponding index
cub::DeviceRadixSort::SortPairs<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num);
GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);
Tensor length_lod;
int* length_lod_data =
length_lod.mutable_data<int>({lod_size}, dev_ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, int> set_zero;
set_zero(dev_ctx, &length_lod, static_cast<int>(0));
int blocks = NumBlocks(real_post_num);
int threads = kNumCUDAThreads;
// get length-based lod by batch ids
GetLengthLoD<<<blocks, threads>>>(real_post_num, out_id_data,
length_lod_data);
std::vector<int> length_lod_cpu(lod_size);
memory::Copy(platform::CPUPlace(), length_lod_cpu.data(), place,
length_lod_data, sizeof(int) * lod_size, dev_ctx.stream());
dev_ctx.Wait();
std::vector<size_t> offset(1, 0);
for (int i = 0; i < lod_size; ++i) {
offset.emplace_back(offset.back() + length_lod_cpu[i]);
}
framework::LoD lod;
lod.emplace_back(offset);
fpn_rois->set_lod(lod);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
collect_fpn_proposals,
ops::GPUCollectFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
float>,
ops::GPUCollectFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
const int kBoxDim = 4;
template <typename T>
struct ScoreWithID {
T score;
int batch_id;
int index;
int level;
ScoreWithID() {
batch_id = -1;
index = -1;
level = -1;
}
ScoreWithID(T score_, int batch_id_, int index_, int level_) {
score = score_;
batch_id = batch_id_;
index = index_;
level = level_;
}
};
template <typename T>
static inline bool CompareByScore(ScoreWithID<T> a, ScoreWithID<T> b) {
return a.score >= b.score;
}
template <typename T>
static inline bool CompareByBatchid(ScoreWithID<T> a, ScoreWithID<T> b) {
return a.batch_id < b.batch_id;
}
template <typename T>
class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto multi_layer_rois =
context.MultiInput<paddle::framework::LoDTensor>("MultiLevelRois");
auto multi_layer_scores =
context.MultiInput<paddle::framework::LoDTensor>("MultiLevelScores");
auto* fpn_rois = context.Output<paddle::framework::LoDTensor>("FpnRois");
int post_nms_topN = context.Attr<int>("post_nms_topN");
PADDLE_ENFORCE_GE(post_nms_topN, 0UL,
"The parameter post_nms_topN must be a positive integer");
// assert that the length of Rois and scores are same
PADDLE_ENFORCE(multi_layer_rois.size() == multi_layer_scores.size(),
"DistributeFpnProposalsOp need 1 level of LoD");
// Check if the lod information of two LoDTensor is same
const int num_fpn_level = multi_layer_rois.size();
std::vector<int> integral_of_all_rois(num_fpn_level + 1, 0);
for (int i = 0; i < num_fpn_level; ++i) {
auto cur_rois_lod = multi_layer_rois[i]->lod().back();
integral_of_all_rois[i + 1] =
integral_of_all_rois[i] + cur_rois_lod[cur_rois_lod.size() - 1];
}
// concatenate all fpn rois scores into a list
// create a vector to store all scores
std::vector<ScoreWithID<T>> scores_of_all_rois(
integral_of_all_rois[num_fpn_level], ScoreWithID<T>());
for (int i = 0; i < num_fpn_level; ++i) {
const T* cur_level_scores = multi_layer_scores[i]->data<T>();
int cur_level_num = integral_of_all_rois[i + 1] - integral_of_all_rois[i];
auto cur_scores_lod = multi_layer_scores[i]->lod().back();
int cur_batch_id = 0;
for (int j = 0; j < cur_level_num; ++j) {
if (j >= cur_scores_lod[cur_batch_id + 1]) {
cur_batch_id++;
}
int cur_index = j + integral_of_all_rois[i];
scores_of_all_rois[cur_index].score = cur_level_scores[j];
scores_of_all_rois[cur_index].index = j;
scores_of_all_rois[cur_index].level = i;
scores_of_all_rois[cur_index].batch_id = cur_batch_id;
}
}
// keep top post_nms_topN rois
// sort the rois by the score
if (post_nms_topN > integral_of_all_rois[num_fpn_level]) {
post_nms_topN = integral_of_all_rois[num_fpn_level];
}
std::stable_sort(scores_of_all_rois.begin(), scores_of_all_rois.end(),
CompareByScore<T>);
scores_of_all_rois.resize(post_nms_topN);
// sort by batch id
std::stable_sort(scores_of_all_rois.begin(), scores_of_all_rois.end(),
CompareByBatchid<T>);
// create a pointer array
std::vector<const T*> multi_fpn_rois_data(num_fpn_level);
for (int i = 0; i < num_fpn_level; ++i) {
multi_fpn_rois_data[i] = multi_layer_rois[i]->data<T>();
}
// initialize the outputs
fpn_rois->mutable_data<T>({post_nms_topN, kBoxDim}, context.GetPlace());
T* fpn_rois_data = fpn_rois->data<T>();
std::vector<size_t> lod0(1, 0);
int cur_batch_id = 0;
for (int i = 0; i < post_nms_topN; ++i) {
int cur_fpn_level = scores_of_all_rois[i].level;
int cur_level_index = scores_of_all_rois[i].index;
memcpy(fpn_rois_data,
multi_fpn_rois_data[cur_fpn_level] + cur_level_index * kBoxDim,
kBoxDim * sizeof(T));
fpn_rois_data += kBoxDim;
if (scores_of_all_rois[i].batch_id != cur_batch_id) {
cur_batch_id = scores_of_all_rois[i].batch_id;
lod0.emplace_back(i);
}
}
lod0.emplace_back(post_nms_topN);
framework::LoD lod;
lod.emplace_back(lod0);
fpn_rois->set_lod(lod);
}
};
} // namespace operators
} // namespace paddle
......@@ -54,6 +54,7 @@ __all__ = [
'multiclass_nms',
'distribute_fpn_proposals',
'box_decoder_and_assign',
'collect_fpn_proposals',
]
......@@ -2512,3 +2513,68 @@ def box_decoder_and_assign(prior_box,
"OutputAssignBox": output_assign_box
})
return decoded_box, output_assign_box
def collect_fpn_proposals(multi_rois,
multi_scores,
min_level,
max_level,
post_nms_top_n,
name=None):
"""
Concat multi-level RoIs (Region of Interest) and select N RoIs
with respect to multi_scores. This operation performs the following steps:
1. Choose num_level RoIs and scores as input: num_level = max_level - min_level
2. Concat multi-level RoIs and scores
3. Sort scores and select post_nms_top_n scores
4. Gather RoIs by selected indices from scores
5. Re-sort RoIs by corresponding batch_id
Args:
multi_ros(list): List of RoIs to collect
multi_scores(list): List of scores
min_level(int): The lowest level of FPN layer to collect
max_level(int): The highest level of FPN layer to collect
post_nms_top_n(int): The number of selected RoIs
name(str|None): A name for this layer(optional)
Returns:
Variable: Output variable of selected RoIs.
Examples:
.. code-block:: python
multi_rois = []
multi_scores = []
for i in range(4):
multi_rois.append(fluid.layers.data(
name='roi_'+str(i), shape=[4], dtype='float32', lod_level=1))
for i in range(4):
multi_scores.append(fluid.layers.data(
name='score_'+str(i), shape=[1], dtype='float32', lod_level=1))
fpn_rois = fluid.layers.collect_fpn_proposals(
multi_rois=multi_rois,
multi_scores=multi_scores,
min_level=2,
max_level=5,
post_nms_top_n=2000)
"""
helper = LayerHelper('collect_fpn_proposals', **locals())
dtype = helper.input_dtype('multi_rois')
num_lvl = max_level - min_level + 1
input_rois = multi_rois[:num_lvl]
input_scores = multi_scores[:num_lvl]
output_rois = helper.create_variable_for_type_inference(dtype)
output_rois.stop_gradient = True
helper.append_op(
type='collect_fpn_proposals',
inputs={
'MultiLevelRois': input_rois,
'MultiLevelScores': input_scores
},
outputs={'FpnRois': output_rois},
attrs={'post_nms_topN': post_nms_top_n})
return output_rois
......@@ -522,6 +522,32 @@ class TestMulticlassNMS(unittest.TestCase):
self.assertIsNotNone(output)
class TestCollectFpnPropsals(unittest.TestCase):
def test_collect_fpn_proposals(self):
program = Program()
with program_guard(program):
multi_bboxes = []
multi_scores = []
for i in range(4):
bboxes = layers.data(
name='rois' + str(i),
shape=[10, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
scores = layers.data(
name='scores' + str(i),
shape=[10, 1],
dtype='float32',
lod_level=1,
append_batch_size=False)
multi_bboxes.append(bboxes)
multi_scores.append(scores)
fpn_rois = layers.collect_fpn_proposals(multi_bboxes, multi_scores,
2, 5, 10)
self.assertIsNotNone(fpn_rois)
class TestDistributeFpnProposals(unittest.TestCase):
def test_distribute_fpn_proposals(self):
program = Program()
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import sys
from op_test import OpTest
class TestCollectFPNProposalstOp(OpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
self.scores_input = [('y%d' % i,
(self.scores[i].reshape(-1, 1), self.rois_lod[i]))
for i in range(self.num_level)]
self.rois, self.lod = self.calc_rois_collect()
inputs_x = [('x%d' % i, (self.roi_inputs[i][:, 1:], self.rois_lod[i]))
for i in range(self.num_level)]
self.inputs = {
'MultiLevelRois': inputs_x,
"MultiLevelScores": self.scores_input
}
self.attrs = {'post_nms_topN': self.post_nms_top_n, }
self.outputs = {'FpnRois': (self.rois, [self.lod])}
def init_test_case(self):
self.post_nms_top_n = 20
self.images_shape = [100, 100]
def resort_roi_by_batch_id(self, rois):
batch_id_list = rois[:, 0]
batch_size = int(batch_id_list.max())
sorted_rois = []
new_lod = []
for batch_id in range(batch_size + 1):
sub_ind = np.where(batch_id_list == batch_id)[0]
sub_rois = rois[sub_ind, 1:]
sorted_rois.append(sub_rois)
new_lod.append(len(sub_rois))
new_rois = np.concatenate(sorted_rois)
return new_rois, new_lod
def calc_rois_collect(self):
roi_inputs = np.concatenate(self.roi_inputs)
scores = np.concatenate(self.scores)
inds = np.argsort(-scores)[:self.post_nms_top_n]
rois = roi_inputs[inds, :]
new_rois, new_lod = self.resort_roi_by_batch_id(rois)
return new_rois, new_lod
def make_rois(self):
self.num_level = 4
self.roi_inputs = []
self.scores = []
self.rois_lod = [[[20, 10]], [[30, 20]], [[20, 30]], [[10, 10]]]
for lvl in range(self.num_level):
rois = []
scores_pb = []
lod = self.rois_lod[lvl][0]
bno = 0
for roi_num in lod:
for i in range(roi_num):
xywh = np.random.rand(4)
xy1 = xywh[0:2] * 20
wh = xywh[2:4] * (self.images_shape - xy1)
xy2 = xy1 + wh
roi = [bno, xy1[0], xy1[1], xy2[0], xy2[1]]
rois.append(roi)
bno += 1
scores_pb.extend(list(np.random.uniform(0.0, 1.0, roi_num)))
rois = np.array(rois).astype("float32")
self.roi_inputs.append(rois)
scores_pb = np.array(scores_pb).astype("float32")
self.scores.append(scores_pb)
def setUp(self):
self.op_type = "collect_fpn_proposals"
self.set_data()
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册