提交 c8610739 编写于 作者: C ceci3

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into npair_loss0

......@@ -330,6 +330,7 @@ paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varar
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '991e934c3e09abf0edec7c9c978b4691'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '397e9e02b451d99c56e20f268fa03f2e'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '005a5ae47d6c8fff721931d69d072b9f'))
paddle.fluid.layers.accuracy (ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)), ('document', '9808534c12c5e739a10f73ebb0b4eafd'))
paddle.fluid.layers.auc (ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)), ('document', 'e0e95334fce92d16c2d9db6e7caffc47'))
......
......@@ -37,8 +37,10 @@ detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc bo
if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS memory cub)
else()
detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc)
endif()
detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
namespace paddle {
namespace operators {
class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("FpnRois"),
"Input(FpnRois) shouldn't be null");
PADDLE_ENFORCE_GE(
ctx->Outputs("MultiFpnRois").size(), 1UL,
"Outputs(MultiFpnRois) of DistributeOp should not be empty");
size_t min_level = static_cast<size_t>(ctx->Attrs().Get<int>("min_level"));
size_t max_level = static_cast<size_t>(ctx->Attrs().Get<int>("max_level"));
PADDLE_ENFORCE_GE(max_level, min_level,
"max_level must not lower than min_level");
// Set the output shape
size_t num_out_rois = max_level - min_level + 1;
std::vector<framework::DDim> outs_dims;
outs_dims.reserve(num_out_rois);
for (size_t i = 0; i < num_out_rois; ++i) {
framework::DDim out_dim = {-1, 4};
outs_dims.push_back(out_dim);
}
ctx->SetOutputsDim("MultiFpnRois", outs_dims);
ctx->SetOutputDim("RestoreIndex", {1, -1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois"));
return framework::OpKernelType(data_type, platform::CPUPlace());
}
};
class DistributeFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("FpnRois", "(LoDTensor) The rois at all levels in shape (-1, 4)");
AddOutput("MultiFpnRois", "(LoDTensor) Output with distribute operator")
.AsDuplicable();
AddOutput("RestoreIndex",
"(Tensor) An array of positive number which is "
"used to restore the order of FpnRois");
AddAttr<int>("min_level",
"The lowest level of FPN layer where the"
" proposals come from");
AddAttr<int>("max_level",
"The highest level of FPN layer where the"
" proposals come from");
AddAttr<int>("refer_level",
"The referring level of FPN layer with"
" specified scale");
AddAttr<int>("refer_scale",
"The referring scale of FPN layer with"
" specified level");
AddComment(R"DOC(
This operator distribute all proposals into different fpn level,
with respect to scale of the proposals, the referring scale and
the referring level. Besides, to restore the order of proposals,
we return an array which indicate the original index of rois in
current proposals.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(distribute_fpn_proposals, ops::DistributeFpnProposalsOp,
ops::DistributeFpnProposalsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(distribute_fpn_proposals,
ops::DistributeFpnProposalsOpKernel<float>,
ops::DistributeFpnProposalsOpKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/memory/allocation/allocator.h>
#include "cub/cub.cuh"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
int const BBoxSize = 4;
struct RangeInitFunctor {
int start_;
int delta_;
int* out_;
__device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; }
};
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
static inline void TransLoD(const int* length_lod, const int lod_size,
int* offset_lod) {
int offset = 0;
for (int i = 0; i < lod_size; ++i) {
offset_lod[i] = offset;
offset += length_lod[i];
}
}
template <typename T>
static __device__ inline T RoIArea(const T* box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static __global__ void GPUDistFpnProposalsHelper(
const int nthreads, const T* rois, const int lod_size,
const int refer_level, const int refer_scale, const int max_level,
const int min_level, int* roi_batch_id_data, int* sub_lod_list,
int* target_lvls) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const T* offset_roi = rois + i * BBoxSize;
int roi_batch_ind = roi_batch_id_data[i];
// get the target level of current rois
T roi_area = RoIArea(offset_roi, false);
T roi_scale = sqrt(roi_area);
int tgt_lvl = floor(log2(roi_scale / refer_scale) + refer_level);
tgt_lvl = min(max_level, max(tgt_lvl, min_level));
target_lvls[i] = tgt_lvl;
// compute number of rois in the same batch and same target level
platform::CudaAtomicAdd(sub_lod_list + tgt_lvl * lod_size + roi_batch_ind,
1);
}
}
template <typename DeviceContext, typename T>
class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* fpn_rois = ctx.Input<paddle::framework::LoDTensor>("FpnRois");
auto multi_fpn_rois = ctx.MultiOutput<LoDTensor>("MultiFpnRois");
auto* restore_index = ctx.Output<Tensor>("RestoreIndex");
const int min_level = ctx.Attr<int>("min_level");
const int max_level = ctx.Attr<int>("max_level");
const int refer_level = ctx.Attr<int>("refer_level");
const int refer_scale = ctx.Attr<int>("refer_scale");
int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty
PADDLE_ENFORCE_EQ(fpn_rois->lod().size(), 1UL,
"DistributeFpnProposalsOp need 1 level of LoD");
auto fpn_rois_lod = fpn_rois->lod().back();
int lod_size = fpn_rois_lod.size() - 1;
int roi_num = fpn_rois_lod[lod_size];
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// get batch id by lod in CPU
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({roi_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
for (int n = 0; n < lod_size; ++n) {
for (size_t i = fpn_rois_lod[n]; i < fpn_rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
// copy batch id list to GPU
Tensor roi_batch_id_list_gpu;
framework::TensorCopySync(roi_batch_id_list, dev_ctx.GetPlace(),
&roi_batch_id_list_gpu);
Tensor sub_lod_list;
sub_lod_list.Resize({num_level, lod_size});
int* sub_lod_list_data = sub_lod_list.mutable_data<int>(dev_ctx.GetPlace());
Tensor target_lvls;
target_lvls.Resize({roi_num});
int* target_lvls_data = target_lvls.mutable_data<int>(dev_ctx.GetPlace());
int blocks = NumBlocks(roi_num);
int threads = kNumCUDAThreads;
// get target levels and sub_lod list
GPUDistFpnProposalsHelper<T><<<blocks, threads>>>(
roi_num, fpn_rois->data<T>(), lod_size, refer_level, refer_scale,
max_level, min_level, roi_batch_id_list_gpu.data<int>(),
sub_lod_list_data, target_lvls_data);
Tensor index_in_t;
int* idx_in = index_in_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx, roi_num);
for_range(RangeInitFunctor{0, 1, idx_in});
Tensor keys_out_t;
int* keys_out = keys_out_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
Tensor index_out_t;
int* idx_out = index_out_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
// Determine temporary device storage requirements
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<int, int>(
nullptr, temp_storage_bytes, target_lvls_data, keys_out, idx_in,
idx_out, roi_num);
// Allocate temporary storage
auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes,
memory::Allocator::kScratchpad);
// Run sorting operation
// sort target level to get corresponding index
cub::DeviceRadixSort::SortPairsDescending<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out,
idx_in, idx_out, roi_num);
int* restore_idx_data =
restore_index->mutable_data<int>({roi_num, 1}, dev_ctx.GetPlace());
// sort current index to get restore index
cub::DeviceRadixSort::SortPairsDescending<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in,
restore_idx_data, roi_num);
Tensor offset_lod;
int* offset_lod_data =
offset_lod.mutable_data<int>({lod_size + 1}, dev_ctx.GetPlace());
for (int i = 0; i < num_level; ++i) {
Tensor sub_lod = sub_lod_list.Slice(i, i + 1);
int* sub_lod_data = sub_lod.data<int>();
// transfer length-based lod to offset-based lod
TransLoD(sub_lod_data, lod_size + 1, offset_lod_data);
int sub_rois_num = offset_lod_data[lod_size];
Tensor sub_idx = index_out_t.Slice(0, sub_rois_num);
multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
dev_ctx.GetPlace());
GPUGather<T>(dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]);
framework::LoD lod;
std::vector<size_t> offset;
memory::Copy(platform::CPUPlace(), offset.data(), place, offset_lod_data,
sizeof(int) * (lod_size + 1), 0);
lod.emplace_back(offset);
multi_fpn_rois[i]->set_lod(lod);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
distribute_fpn_proposals,
ops::GPUDistributeFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
float>,
ops::GPUDistributeFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
const int kBoxDim = 4;
template <typename T>
static inline T BBoxArea(const T* box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <typename T>
class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* fpn_rois = context.Input<paddle::framework::LoDTensor>("FpnRois");
auto multi_fpn_rois =
context.MultiOutput<paddle::framework::LoDTensor>("MultiFpnRois");
auto* restore_index =
context.Output<paddle::framework::Tensor>("RestoreIndex");
const int min_level = context.Attr<int>("min_level");
const int max_level = context.Attr<int>("max_level");
const int refer_level = context.Attr<int>("refer_level");
const int refer_scale = context.Attr<int>("refer_scale");
const int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty
PADDLE_ENFORCE_EQ(fpn_rois->lod().size(), 1UL,
"DistributeFpnProposalsOp need 1 level of LoD");
auto fpn_rois_lod = fpn_rois->lod().back();
int fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1];
std::vector<int> target_level;
// std::vector<int> target_level(fpn_rois_num, -1);
// record the number of rois in each level
std::vector<int> num_rois_level(num_level, 0);
std::vector<int> num_rois_level_integral(num_level + 1, 0);
for (int i = 0; i < fpn_rois_lod.size() - 1; ++i) {
Tensor fpn_rois_slice =
fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
const T* rois_data = fpn_rois_slice.data<T>();
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
// get the target level of current rois
T roi_scale = std::sqrt(BBoxArea(rois_data, false));
int tgt_lvl =
std::floor(std::log2(roi_scale / refer_scale) + refer_level);
tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level));
target_level.push_back(tgt_lvl);
num_rois_level[tgt_lvl - min_level]++;
rois_data += kBoxDim;
}
}
// define the output rois
// pointer which point to each level fpn rois
std::vector<T*> multi_fpn_rois_data(num_level);
// lod0 which will record the offset information of each level rois
std::vector<std::vector<size_t>> multi_fpn_rois_lod0;
for (int i = 0; i < num_level; ++i) {
// allocate memory for each level rois
multi_fpn_rois[i]->mutable_data<T>({num_rois_level[i], kBoxDim},
context.GetPlace());
multi_fpn_rois_data[i] = multi_fpn_rois[i]->data<T>();
std::vector<size_t> lod0(1, 0);
multi_fpn_rois_lod0.push_back(lod0);
// statistic start point for each level rois
num_rois_level_integral[i + 1] =
num_rois_level_integral[i] + num_rois_level[i];
}
restore_index->mutable_data<int>({1, fpn_rois_num}, context.GetPlace());
int* restore_index_data = restore_index->data<int>();
std::vector<int> restore_index_inter(fpn_rois_num, -1);
// distribute the rois into different fpn level by target level
for (int i = 0; i < fpn_rois_lod.size() - 1; ++i) {
Tensor fpn_rois_slice =
fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
const T* rois_data = fpn_rois_slice.data<T>();
size_t cur_offset = fpn_rois_lod[i];
// std::vector<size_t > lod_offset[num_level];
for (int j = 0; j < num_level; j++) {
multi_fpn_rois_lod0[j].push_back(multi_fpn_rois_lod0[j][i]);
}
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
int lvl = target_level[cur_offset + j];
memcpy(multi_fpn_rois_data[lvl - min_level], rois_data,
kBoxDim * sizeof(T));
multi_fpn_rois_data[lvl - min_level] += kBoxDim;
int index_in_shuffle = num_rois_level_integral[lvl - min_level] +
multi_fpn_rois_lod0[lvl - min_level][i + 1];
restore_index_inter[index_in_shuffle] = cur_offset + j;
multi_fpn_rois_lod0[lvl - min_level][i + 1]++;
rois_data += kBoxDim;
}
}
for (int i = 0; i < fpn_rois_num; ++i) {
restore_index_data[restore_index_inter[i]] = i;
}
// merge lod information into LoDTensor
for (int i = 0; i < num_level; ++i) {
framework::LoD lod;
lod.emplace_back(multi_fpn_rois_lod0[i]);
multi_fpn_rois[i]->set_lod(lod);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -51,6 +51,7 @@ __all__ = [
'yolov3_loss',
'box_clip',
'multiclass_nms',
'distribute_fpn_proposals',
'box_decoder_and_assign',
]
......@@ -2224,6 +2225,79 @@ def multiclass_nms(bboxes,
return output
def distribute_fpn_proposals(fpn_rois,
min_level,
max_level,
refer_level,
refer_scale,
name=None):
"""
In Feature Pyramid Networks (FPN) models, it is needed to distribute all
proposals into different FPN level, with respect to scale of the proposals,
the referring scale and the referring level. Besides, to restore the order
of proposals, we return an array which indicates the original index of rois
in current proposals. To compute FPN level for each roi, the formula is
given as follows:
.. math::
roi\_scale &= \sqrt{BBoxArea(fpn\_roi)}
level = floor(&\log(\\frac{roi\_scale}{refer\_scale}) + refer\_level)
where BBoxArea is a function to compute the area of each roi.
Args:
fpn_rois(variable): The input fpn_rois, the second dimension is 4.
min_level(int): The lowest level of FPN layer where the proposals come
from.
max_level(int): The highest level of FPN layer where the proposals
come from.
refer_level(int): The referring level of FPN layer with specified scale.
refer_scale(int): The referring scale of FPN layer with specified level.
name(str|None): The name of this operator.
Returns:
tuple:
A tuple(multi_rois, restore_ind) is returned. The multi_rois is
a list of segmented tensor variables. The restore_ind is a 2D
Tensor with shape [N, 1], N is the number of total rois. It is
used to restore the order of fpn_rois.
Examples:
.. code-block:: python
fpn_rois = fluid.layers.data(
name='data', shape=[4], dtype='float32', lod_level=1)
multi_rois, restore_ind = fluid.layers.distribute_fpn_proposals(
fpn_rois=fpn_rois,
min_level=2,
max_level=5,
refer_level=4,
refer_scale=224)
"""
helper = LayerHelper('distribute_fpn_proposals', **locals())
dtype = helper.input_dtype()
num_lvl = max_level - min_level + 1
multi_rois = [
helper.create_variable_for_type_inference(dtype) for i in range(num_lvl)
]
restore_ind = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op(
type='distribute_fpn_proposals',
inputs={'FpnRois': fpn_rois},
outputs={'MultiFpnRois': multi_rois,
'RestoreIndex': restore_ind},
attrs={
'min_level': min_level,
'max_level': max_level,
'refer_level': refer_level,
'refer_scale': refer_scale
})
return multi_rois, restore_ind
@templatedoc()
def box_decoder_and_assign(prior_box,
prior_box_var,
......
......@@ -504,5 +504,21 @@ class TestMulticlassNMS(unittest.TestCase):
self.assertIsNotNone(output)
class TestDistributeFpnProposals(unittest.TestCase):
def test_distribute_fpn_proposals(self):
program = Program()
with program_guard(program):
fpn_rois = fluid.layers.data(
name='data', shape=[4], dtype='float32', lod_level=1)
multi_rois, restore_ind = layers.distribute_fpn_proposals(
fpn_rois=fpn_rois,
min_level=2,
max_level=5,
refer_level=4,
refer_scale=224)
self.assertIsNotNone(multi_rois)
self.assertIsNotNone(restore_ind)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import sys
from op_test import OpTest
class TestDistributeFPNProposalsOp(OpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
self.rois_fpn, self.rois_idx_restore = self.calc_rois_distribute()
self.inputs = {'FpnRois': (self.rois[:, 1:5], self.rois_lod)}
self.attrs = {
'max_level': self.roi_max_level,
'min_level': self.roi_min_level,
'refer_scale': self.canonical_scale,
'refer_level': self.canonical_level
}
output = [('out%d' % i, self.rois_fpn[i])
for i in range(len(self.rois_fpn))]
self.outputs = {
'MultiFpnRois': output,
'RestoreIndex': self.rois_idx_restore
}
def init_test_case(self):
self.roi_max_level = 5
self.roi_min_level = 2
self.canonical_scale = 224
self.canonical_level = 4
self.images_shape = [512, 512]
def boxes_area(self, boxes):
w = (boxes[:, 2] - boxes[:, 0] + 1)
h = (boxes[:, 3] - boxes[:, 1] + 1)
areas = w * h
assert np.all(areas >= 0), 'Negative areas founds'
return areas
def map_rois_to_fpn_levels(self, rois, lvl_min, lvl_max):
s = np.sqrt(self.boxes_area(rois))
s0 = self.canonical_scale
lvl0 = self.canonical_level
target_lvls = np.floor(lvl0 + np.log2(s / s0 + 1e-6))
target_lvls = np.clip(target_lvls, lvl_min, lvl_max)
return target_lvls
def get_sub_lod(self, sub_lvl):
sub_lod = []
max_batch_id = sub_lvl[-1]
for i in range(max_batch_id.astype(np.int32) + 1):
sub_lod.append(np.where(sub_lvl == i)[0].size)
return sub_lod
def add_multilevel_roi(self, rois, target_lvls, lvl_min, lvl_max):
rois_idx_order = np.empty((0, ))
rois_fpn = []
for lvl in range(lvl_min, lvl_max + 1):
idx_lvl = np.where(target_lvls == lvl)[0]
if len(idx_lvl) == 0:
rois_fpn.append((np.empty(shape=(0, 4)), [[0, 0]]))
continue
sub_lod = self.get_sub_lod(rois[idx_lvl, 0])
rois_fpn.append((rois[idx_lvl, 1:], [sub_lod]))
rois_idx_order = np.concatenate((rois_idx_order, idx_lvl))
rois_idx_restore = np.argsort(rois_idx_order).astype(
np.int32, copy=False)
return rois_fpn, rois_idx_restore
def calc_rois_distribute(self):
lvl_min = self.roi_min_level
lvl_max = self.roi_max_level
target_lvls = self.map_rois_to_fpn_levels(self.rois[:, 1:5], lvl_min,
lvl_max)
rois_fpn, rois_idx_restore = self.add_multilevel_roi(
self.rois, target_lvls, lvl_min, lvl_max)
return rois_fpn, rois_idx_restore
def make_rois(self):
self.rois_lod = [[100, 200]]
rois = []
lod = self.rois_lod[0]
bno = 0
for roi_num in lod:
for i in range(roi_num):
xywh = np.random.rand(4)
xy1 = xywh[0:2] * 20
wh = xywh[2:4] * (self.images_shape - xy1)
xy2 = xy1 + wh
roi = [bno, xy1[0], xy1[1], xy2[0], xy2[1]]
rois.append(roi)
bno += 1
self.rois = np.array(rois).astype("float32")
def setUp(self):
self.op_type = "distribute_fpn_proposals"
self.set_data()
def test_check_output(self):
self.check_output()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册