提交 0f652f30 编写于 作者: J jerrywgz

add distribute fpn proposals op, test=develop

上级 685a20ef
......@@ -327,6 +327,7 @@ paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], vararg
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_clip ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
paddle.fluid.layers.distribute_fpn_proposals ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
......
......@@ -33,6 +33,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc)
if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
namespace paddle {
namespace operators {
class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("FpnRois"),
"Input(FpnRois) shouldn't be null");
PADDLE_ENFORCE_GE(
ctx->Outputs("MultiFpnRois").size(), 1UL,
"Outputs(MultiFpnRois) of DistributeOp should not be empty");
size_t min_level = static_cast<size_t>(ctx->Attrs().Get<int>("min_level"));
size_t max_level = static_cast<size_t>(ctx->Attrs().Get<int>("max_level"));
PADDLE_ENFORCE_GE(max_level, min_level,
"max_level must not lower than min_level");
// Set the output shape
size_t num_out_rois = max_level - min_level + 1;
std::vector<framework::DDim> outs_dims;
outs_dims.reserve(num_out_rois);
for (size_t i = 0; i < num_out_rois; ++i) {
framework::DDim out_dim = {-1, 4};
outs_dims.push_back(out_dim);
}
ctx->SetOutputsDim("MultiFpnRois", outs_dims);
ctx->SetOutputDim("RestoreIndex", {1, -1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois"));
return framework::OpKernelType(data_type, platform::CPUPlace());
}
};
class DistributeFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("FpnRois", "(LoDTensor) The rois at all levels in shape (-1, 4)");
AddOutput("MultiFpnRois", "(LoDTensor) Output with distribute operator")
.AsDuplicable();
AddOutput("RestoreIndex",
"(Tensor) An array of positive number which is "
"used to restore the order of FpnRois");
AddAttr<int>("min_level",
"The lowest level of FPN layer where the"
" proposals come from");
AddAttr<int>("max_level",
"The highest level of FPN layer where the"
" proposals come from");
AddAttr<int>("refer_level",
"The referring level of FPN layer with"
" specified scale");
AddAttr<int>("refer_scale",
"The referring scale of FPN layer with"
" specified level");
AddComment(R"DOC(
This operator distribute all proposals into different fpn level,
with respect to scale of the proposals, the referring scale and
the referring level. Besides, to restore the order of proposals,
we return an array which indicate the original index of rois in
current proposals.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(distribute_fpn_proposals, ops::DistributeFpnProposalsOp,
ops::DistributeFpnProposalsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(distribute_fpn_proposals,
ops::DistributeFpnProposalsOpKernel<float>,
ops::DistributeFpnProposalsOpKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
const int kBoxDim = 4;
template <typename T>
static inline T BBoxArea(const T* box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <typename T>
class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* fpn_rois = context.Input<paddle::framework::LoDTensor>("FpnRois");
auto multi_fpn_rois =
context.MultiOutput<paddle::framework::LoDTensor>("MultiFpnRois");
auto* restore_index =
context.Output<paddle::framework::Tensor>("RestoreIndex");
const int min_level = context.Attr<int>("min_level");
const int max_level = context.Attr<int>("max_level");
const int refer_level = context.Attr<int>("refer_level");
const int refer_scale = context.Attr<int>("refer_scale");
const int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty
PADDLE_ENFORCE_EQ(fpn_rois->lod().size(), 1UL,
"DistributeFpnProposalsOp need 1 level of LoD");
auto fpn_rois_lod = fpn_rois->lod().back();
int fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1];
std::vector<int> target_level;
// std::vector<int> target_level(fpn_rois_num, -1);
// record the number of rois in each level
std::vector<int> num_rois_level(num_level, 0);
std::vector<int> num_rois_level_integral(num_level + 1, 0);
for (int i = 0; i < fpn_rois_lod.size() - 1; ++i) {
Tensor fpn_rois_slice =
fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
const T* rois_data = fpn_rois_slice.data<T>();
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
// get the target level of current rois
T roi_scale = std::sqrt(BBoxArea(rois_data, false));
int tgt_lvl =
std::floor(std::log2(roi_scale / refer_scale) + refer_level);
tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level));
target_level.push_back(tgt_lvl);
num_rois_level[tgt_lvl - min_level]++;
rois_data += kBoxDim;
}
}
// define the output rois
// pointer which point to each level fpn rois
T* multi_fpn_rois_data[num_level];
// lod0 which will record the offset information of each level rois
std::vector<std::vector<size_t>> multi_fpn_rois_lod0;
for (int i = 0; i < num_level; ++i) {
// allocate memory for each level rois
multi_fpn_rois[i]->mutable_data<T>({num_rois_level[i], kBoxDim},
context.GetPlace());
multi_fpn_rois_data[i] = multi_fpn_rois[i]->data<T>();
std::vector<size_t> lod0(1, 0);
multi_fpn_rois_lod0.push_back(lod0);
// statistic start point for each level rois
num_rois_level_integral[i + 1] =
num_rois_level_integral[i] + num_rois_level[i];
}
restore_index->mutable_data<int>({1, fpn_rois_num}, context.GetPlace());
int* restore_index_data = restore_index->data<int>();
std::vector<int> restore_index_inter(fpn_rois_num, -1);
// distribute the rois into different fpn level by target level
for (int i = 0; i < fpn_rois_lod.size() - 1; ++i) {
Tensor fpn_rois_slice =
fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
const T* rois_data = fpn_rois_slice.data<T>();
size_t cur_offset = fpn_rois_lod[i];
// std::vector<size_t > lod_offset[num_level];
for (int j = 0; j < num_level; j++) {
multi_fpn_rois_lod0[j].push_back(multi_fpn_rois_lod0[j][i]);
}
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
int lvl = target_level[cur_offset + j];
memcpy(multi_fpn_rois_data[lvl - min_level], rois_data,
kBoxDim * sizeof(T));
multi_fpn_rois_data[lvl - min_level] += kBoxDim;
int index_in_shuffle = num_rois_level_integral[lvl - min_level] +
multi_fpn_rois_lod0[lvl - min_level][i + 1];
restore_index_inter[index_in_shuffle] = cur_offset + j;
multi_fpn_rois_lod0[lvl - min_level][i + 1]++;
rois_data += kBoxDim;
}
}
for (int i = 0; i < fpn_rois_num; ++i) {
restore_index_data[restore_index_inter[i]] = i;
}
// merge lod information into LoDTensor
for (int i = 0; i < num_level; ++i) {
framework::LoD lod;
lod.emplace_back(multi_fpn_rois_lod0[i]);
multi_fpn_rois[i]->set_lod(lod);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -51,6 +51,7 @@ __all__ = [
'yolov3_loss',
'box_clip',
'multiclass_nms',
'distribute_fpn_proposals',
]
......@@ -2220,3 +2221,77 @@ def multiclass_nms(bboxes,
output.stop_gradient = True
return output
def distribute_fpn_proposals(fpn_rois,
min_level,
max_level,
refer_level,
refer_scale,
name=None):
"""
Distribute all proposals into different fpn level, with respect to scale
of the proposals, the referring scale and the referring level. Besides, to
restore the order of proposals, we return an array which indicate the
original index of rois in current proposals. To compute fpn level for each
roi, the formula is given as follows:
.. code-block:: text
roi_scale = sqrt(BBoxArea(fpn_roi));
level = floor(log2(roi_scale / refer_scale) + refer_level)
where BBoxArea is the function to compute the area of each roi:
.. code-block:: text
w = fpn_roi[2] - fpn_roi[0]
h = fpn_roi[3] - fpn_roi[1]
area = (w + 1) * (h + 1)
Args:
fpn_rois(variable): The input fpn_rois, the last dimension is 4.
min_level(int): The lowest level of FPN layer where the proposals come
from.
max_level(int): The highest level of FPN layer where the proposals
come from.
refer_level(int): The referring level of FPN layer with specified scale.
refer_scale(int): The referring scale of FPN layer with specified level.
Returns:
List(variable): The list of segmented tensor variables.
Variable: An array of positive number which is used to restore the
order of fpn_rois.
Examples:
.. code-block:: python
fpn_rois = fluid.layers.data(
name='data', shape=[4], dtype='float32', lod_level=1)
multi_rois, restore_ind = fluid.layers.distribute_fpn_proposals(
fpn_rois=fpn_rois,
min_level=2,
max_level=5,
refer_level=4,
refer_scale=224)
"""
helper = LayerHelper('distribute_fpn_proposals', **locals())
dtype = helper.input_dtype()
num_lvl = max_level - min_level + 1
multi_rois = [
helper.create_variable_for_type_inference(dtype) for i in range(num_lvl)
]
restore_ind = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op(
type='distribute_fpn_proposals',
inputs={'FpnRois': fpn_rois},
outputs={'MultiFpnRois': multi_rois,
'RestoreIndex': restore_ind},
attrs={
'min_level': min_level,
'max_level': max_level,
'refer_level': refer_level,
'refer_scale': refer_scale
})
return multi_rois, restore_ind
......@@ -504,5 +504,21 @@ class TestMulticlassNMS(unittest.TestCase):
self.assertIsNotNone(output)
class TestDistributeFpnProposals(unittest.TestCase):
def test_distribute_fpn_proposals(self):
program = Program()
with program_guard(program):
fpn_rois = fluid.layers.data(
name='data', shape=[4], dtype='float32', lod_level=1)
multi_rois, restore_ind = layers.distribute_fpn_proposals(
fpn_rois=fpn_rois,
min_level=2,
max_level=5,
refer_level=4,
refer_scale=224)
self.assertIsNotNone(multi_rois)
self.assertIsNotNone(restore_ind)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import sys
from op_test import OpTest
class TestDistributeFPNProposalsOp(OpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
self.rois_fpn, self.rois_idx_restore = self.calc_rois_distribute()
self.inputs = {'FpnRois': (self.rois[:, 1:5], self.rois_lod)}
self.attrs = {
'max_level': self.roi_max_level,
'min_level': self.roi_min_level,
'refer_scale': self.canonical_scale,
'refer_level': self.canonical_level
}
output = [('out%d' % i, self.rois_fpn[i])
for i in range(len(self.rois_fpn))]
self.outputs = {
'MultiFpnRois': output,
'RestoreIndex': self.rois_idx_restore
}
def init_test_case(self):
self.roi_max_level = 5
self.roi_min_level = 2
self.canonical_scale = 224
self.canonical_level = 4
self.images_shape = [512, 512]
def boxes_area(self, boxes):
w = (boxes[:, 2] - boxes[:, 0] + 1)
h = (boxes[:, 3] - boxes[:, 1] + 1)
areas = w * h
assert np.all(areas >= 0), 'Negative areas founds'
return areas
def map_rois_to_fpn_levels(self, rois, lvl_min, lvl_max):
s = np.sqrt(self.boxes_area(rois))
s0 = self.canonical_scale
lvl0 = self.canonical_level
target_lvls = np.floor(lvl0 + np.log2(s / s0 + 1e-6))
target_lvls = np.clip(target_lvls, lvl_min, lvl_max)
return target_lvls
def get_sub_lod(self, sub_lvl):
sub_lod = []
max_batch_id = sub_lvl[-1]
for i in range(max_batch_id.astype(np.int32) + 1):
sub_lod.append(np.where(sub_lvl == i)[0].size)
return sub_lod
def add_multilevel_roi(self, rois, target_lvls, lvl_min, lvl_max):
rois_idx_order = np.empty((0, ))
rois_fpn = []
for lvl in range(lvl_min, lvl_max + 1):
idx_lvl = np.where(target_lvls == lvl)[0]
if len(idx_lvl) == 0:
rois_fpn.append((np.empty(shape=(0, 4)), [[0, 0]]))
continue
sub_lod = self.get_sub_lod(rois[idx_lvl, 0])
rois_fpn.append((rois[idx_lvl, 1:], [sub_lod]))
rois_idx_order = np.concatenate((rois_idx_order, idx_lvl))
rois_idx_restore = np.argsort(rois_idx_order).astype(
np.int32, copy=False)
return rois_fpn, rois_idx_restore
def calc_rois_distribute(self):
lvl_min = self.roi_min_level
lvl_max = self.roi_max_level
target_lvls = self.map_rois_to_fpn_levels(self.rois[:, 1:5], lvl_min,
lvl_max)
rois_fpn, rois_idx_restore = self.add_multilevel_roi(
self.rois, target_lvls, lvl_min, lvl_max)
return rois_fpn, rois_idx_restore
def make_rois(self):
self.rois_lod = [[100, 200]]
rois = []
lod = self.rois_lod[0]
bno = 0
for roi_num in lod:
for i in range(roi_num):
xywh = np.random.rand(4)
xy1 = xywh[0:2] * 20
wh = xywh[2:4] * (self.images_shape - xy1)
xy2 = xy1 + wh
roi = [bno, xy1[0], xy1[1], xy2[0], xy2[1]]
rois.append(roi)
bno += 1
self.rois = np.array(rois).astype("float32")
def setUp(self):
self.op_type = "distribute_fpn_proposals"
self.set_data()
def test_check_output(self):
self.check_output()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册