未验证 提交 d1e1f174 编写于 作者: W wangguanzhong 提交者: GitHub

fix generate_proposal_labels in cascade-rcnn series model, test=develop (#27892)

* fix generate_proposal_labels in cascade-rcnn series model, test=develop

* fix example code & unittest, test=develop

* update code from review comments, test=develop
上级 afe68cb9
......@@ -149,5 +149,20 @@ void ClipTiledBoxes(const platform::DeviceContext& ctx,
}
}
// Calculate max IoU between each box and ground-truth and
// each row represents one box
template <typename T>
void MaxIoU(const framework::Tensor& iou, framework::Tensor* max_iou) {
const T* iou_data = iou.data<T>();
int row = iou.dims()[0];
int col = iou.dims()[1];
T* max_iou_data = max_iou->data<T>();
for (int i = 0; i < row; ++i) {
const T* v = iou_data + i * col;
T max_v = *std::max_element(v, v + col);
max_iou_data[i] = max_v;
}
}
} // namespace operators
} // namespace paddle
......@@ -33,6 +33,28 @@ void AppendRois(LoDTensor* out, int64_t offset, Tensor* to_add) {
memcpy(out_data + offset, to_add_data, to_add->numel() * sizeof(T));
}
// Filter the ground-truth in RoIs and the RoIs with non-positive area.
// The ground-truth has max overlap with itself so the max_overlap is 1
// and the corresponding RoI will be removed.
template <typename T>
void FilterRoIs(const platform::DeviceContext& ctx, const Tensor& rpn_rois,
const Tensor& max_overlap, Tensor* keep) {
const T* rpn_rois_dt = rpn_rois.data<T>();
const T* max_overlap_dt = max_overlap.data<T>();
int rois_num = max_overlap.numel();
keep->Resize({rois_num});
int* keep_data = keep->mutable_data<int>(ctx.GetPlace());
int keep_len = 0;
for (int i = 0; i < rois_num; ++i) {
if ((rpn_rois_dt[i * 4 + 2] - rpn_rois_dt[i * 4 + 0] + 1) > 0 &&
(rpn_rois_dt[i * 4 + 3] - rpn_rois_dt[i * 4 + 1] + 1) > 0 &&
max_overlap_dt[i] < 1.) {
keep_data[keep_len++] = i;
}
}
keep->Resize({keep_len});
}
class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -98,12 +120,21 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
im_info_dims.size(), im_info_dims));
int class_nums = ctx->Attrs().Get<int>("class_nums");
bool is_cascade_rcnn = ctx->Attrs().Get<bool>("is_cascade_rcnn");
if (is_cascade_rcnn) {
PADDLE_ENFORCE_EQ(
ctx->HasInput("MaxOverlap"), true,
platform::errors::NotFound(
"Input(MaxOverlap) of GenerateProposalLabelsOp "
"should not be null when is_cascade_rcnn is True."));
}
ctx->SetOutputDim("Rois", {-1, 4});
ctx->SetOutputDim("LabelsInt32", {-1, 1});
ctx->SetOutputDim("BboxTargets", {-1, 4 * class_nums});
ctx->SetOutputDim("BboxInsideWeights", {-1, 4 * class_nums});
ctx->SetOutputDim("BboxOutsideWeights", {-1, 4 * class_nums});
ctx->SetOutputDim("MaxOverlapWithGT", {-1});
}
protected:
......@@ -142,7 +173,6 @@ std::vector<std::vector<int>> SampleFgBgGt(
int64_t row = iou->dims()[0];
int64_t col = iou->dims()[1];
float epsilon = 0.00001;
const T* rpn_rois_dt = rpn_rois.data<T>();
// Follow the Faster RCNN's implementation
for (int64_t i = 0; i < row; ++i) {
const T* v = proposal_to_gt_overlaps + i * col;
......@@ -151,11 +181,6 @@ std::vector<std::vector<int>> SampleFgBgGt(
if ((i < gt_num) && (crowd_data[i])) {
max_overlap = -1.0;
}
if (is_cascade_rcnn &&
((rpn_rois_dt[i * 4 + 2] - rpn_rois_dt[i * 4 + 0] + 1) <= 0 ||
(rpn_rois_dt[i * 4 + 3] - rpn_rois_dt[i * 4 + 1] + 1) <= 0)) {
continue;
}
if (max_overlap >= fg_thresh) {
// fg mapped gt label index
for (int64_t j = 0; j < col; ++j) {
......@@ -232,12 +257,13 @@ std::vector<std::vector<int>> SampleFgBgGt(
template <typename T>
void GatherBoxesLabels(const platform::CPUDeviceContext& context,
const Tensor& boxes, const Tensor& gt_boxes,
const Tensor& gt_classes,
const Tensor& boxes, const Tensor& max_overlap,
const Tensor& gt_boxes, const Tensor& gt_classes,
const std::vector<int>& fg_inds,
const std::vector<int>& bg_inds,
const std::vector<int>& gt_inds, Tensor* sampled_boxes,
Tensor* sampled_labels, Tensor* sampled_gts) {
Tensor* sampled_labels, Tensor* sampled_gts,
Tensor* sampled_max_overlap) {
int fg_num = fg_inds.size();
int bg_num = bg_inds.size();
Tensor fg_inds_t, bg_inds_t, gt_box_inds_t, gt_label_inds_t;
......@@ -264,6 +290,13 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context,
bg_labels.mutable_data<int>({bg_num}, context.GetPlace());
math::set_constant(context, &bg_labels, 0);
Concat<int>(context, fg_labels, bg_labels, sampled_labels);
Tensor fg_max_overlap, bg_max_overlap;
fg_max_overlap.mutable_data<T>({fg_num}, context.GetPlace());
CPUGather<T>(context, max_overlap, fg_inds_t, &fg_max_overlap);
bg_max_overlap.mutable_data<T>({bg_num}, context.GetPlace());
CPUGather<T>(context, max_overlap, bg_inds_t, &bg_max_overlap);
Concat<T>(context, fg_max_overlap, bg_max_overlap, sampled_max_overlap);
}
template <typename T>
......@@ -274,43 +307,58 @@ std::vector<Tensor> SampleRoisForOneImage(
const float fg_thresh, const float bg_thresh_hi, const float bg_thresh_lo,
const std::vector<float>& bbox_reg_weights, const int class_nums,
std::minstd_rand engine, bool use_random, bool is_cascade_rcnn,
bool is_cls_agnostic) {
bool is_cls_agnostic, const Tensor& max_overlap) {
// 1.1 map to original image
auto im_scale = im_info.data<T>()[2];
Tensor rpn_rois;
rpn_rois.mutable_data<T>(rpn_rois_in.dims(), context.GetPlace());
const T* rpn_rois_in_dt = rpn_rois_in.data<T>();
T* rpn_rois_dt = rpn_rois.data<T>();
int gt_num = gt_boxes.dims()[0] * 4;
for (int i = 0; i < rpn_rois.numel(); ++i) {
if (i < gt_num && is_cascade_rcnn) {
rpn_rois_dt[i] = rpn_rois_in_dt[i];
rpn_rois_dt[i] = rpn_rois_in_dt[i] / im_scale;
}
int proposals_num = 1;
if (is_cascade_rcnn) {
Tensor keep;
FilterRoIs<T>(context, rpn_rois, max_overlap, &keep);
Tensor roi_filter;
// Tensor box_filter;
if (keep.numel() == 0) {
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
roi_filter.mutable_data<T>({proposals_num, kBoxDim}, context.GetPlace());
set_zero(context, &roi_filter, static_cast<T>(0));
} else {
rpn_rois_dt[i] = rpn_rois_in_dt[i] / im_scale;
proposals_num = keep.numel();
roi_filter.mutable_data<T>({proposals_num, kBoxDim}, context.GetPlace());
CPUGather<T>(context, rpn_rois, keep, &roi_filter);
}
T* roi_filter_dt = roi_filter.data<T>();
memcpy(rpn_rois_dt, roi_filter_dt, roi_filter.numel() * sizeof(T));
rpn_rois.Resize(roi_filter.dims());
} else {
proposals_num = rpn_rois.dims()[0];
}
// 1.2 compute overlaps
int proposals_num = rpn_rois.dims()[0];
if (!is_cascade_rcnn) {
proposals_num += gt_boxes.dims()[0];
}
proposals_num += gt_boxes.dims()[0];
Tensor proposal_to_gt_overlaps;
proposal_to_gt_overlaps.mutable_data<T>({proposals_num, gt_boxes.dims()[0]},
context.GetPlace());
Tensor boxes;
boxes.mutable_data<T>({proposals_num, kBoxDim}, context.GetPlace());
if (!is_cascade_rcnn) {
Concat<T>(context, gt_boxes, rpn_rois, &boxes);
} else {
T* boxes_dt = boxes.data<T>();
for (int i = 0; i < boxes.numel(); ++i) {
boxes_dt[i] = rpn_rois_dt[i];
}
}
Concat<T>(context, gt_boxes, rpn_rois, &boxes);
BboxOverlaps<T>(boxes, gt_boxes, &proposal_to_gt_overlaps);
Tensor proposal_with_max_overlap;
proposal_with_max_overlap.mutable_data<T>({proposals_num},
context.GetPlace());
MaxIoU<T>(proposal_to_gt_overlaps, &proposal_with_max_overlap);
// Generate proposal index
std::vector<std::vector<int>> fg_bg_gt =
SampleFgBgGt<T>(context, &proposal_to_gt_overlaps, is_crowd,
......@@ -321,7 +369,7 @@ std::vector<Tensor> SampleRoisForOneImage(
std::vector<int> mapped_gt_inds = fg_bg_gt[2]; // mapped_gt_labels
// Gather boxes and labels
Tensor sampled_boxes, sampled_labels, sampled_gts;
Tensor sampled_boxes, sampled_labels, sampled_gts, sampled_max_overlap;
int fg_num = fg_inds.size();
int bg_num = bg_inds.size();
int boxes_num = fg_num + bg_num;
......@@ -329,9 +377,11 @@ std::vector<Tensor> SampleRoisForOneImage(
sampled_boxes.mutable_data<T>(bbox_dim, context.GetPlace());
sampled_labels.mutable_data<int>({boxes_num}, context.GetPlace());
sampled_gts.mutable_data<T>({fg_num, kBoxDim}, context.GetPlace());
GatherBoxesLabels<T>(context, boxes, gt_boxes, gt_classes, fg_inds, bg_inds,
mapped_gt_inds, &sampled_boxes, &sampled_labels,
&sampled_gts);
sampled_max_overlap.mutable_data<T>({boxes_num}, context.GetPlace());
GatherBoxesLabels<T>(context, boxes, proposal_with_max_overlap, gt_boxes,
gt_classes, fg_inds, bg_inds, mapped_gt_inds,
&sampled_boxes, &sampled_labels, &sampled_gts,
&sampled_max_overlap);
// Compute targets
Tensor bbox_targets_single;
......@@ -390,6 +440,7 @@ std::vector<Tensor> SampleRoisForOneImage(
res.emplace_back(bbox_targets);
res.emplace_back(bbox_inside_weights);
res.emplace_back(bbox_outside_weights);
res.emplace_back(sampled_max_overlap);
return res;
}
......@@ -409,6 +460,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
auto* bbox_inside_weights = context.Output<LoDTensor>("BboxInsideWeights");
auto* bbox_outside_weights =
context.Output<LoDTensor>("BboxOutsideWeights");
auto* max_overlap_with_gt = context.Output<LoDTensor>("MaxOverlapWithGT");
int batch_size_per_im = context.Attr<int>("batch_size_per_im");
float fg_fraction = context.Attr<float>("fg_fraction");
......@@ -446,16 +498,21 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
"received level of LoD is [%d], LoD is [%s].",
gt_boxes->lod().size(), gt_boxes->lod()));
int64_t n = static_cast<int64_t>(rpn_rois->lod().back().size() - 1);
rois->mutable_data<T>({n * batch_size_per_im, kBoxDim}, context.GetPlace());
labels_int32->mutable_data<int>({n * batch_size_per_im, 1},
context.GetPlace());
bbox_targets->mutable_data<T>({n * batch_size_per_im, kBoxDim * class_nums},
int64_t rois_num = rpn_rois->dims()[0];
int64_t gts_num = gt_boxes->dims()[0];
int64_t init_num =
is_cascade_rcnn ? rois_num + gts_num : n * batch_size_per_im;
rois->mutable_data<T>({init_num, kBoxDim}, context.GetPlace());
labels_int32->mutable_data<int>({init_num, 1}, context.GetPlace());
bbox_targets->mutable_data<T>({init_num, kBoxDim * class_nums},
context.GetPlace());
bbox_inside_weights->mutable_data<T>(
{n * batch_size_per_im, kBoxDim * class_nums}, context.GetPlace());
bbox_outside_weights->mutable_data<T>(
{n * batch_size_per_im, kBoxDim * class_nums}, context.GetPlace());
bbox_inside_weights->mutable_data<T>({init_num, kBoxDim * class_nums},
context.GetPlace());
bbox_outside_weights->mutable_data<T>({init_num, kBoxDim * class_nums},
context.GetPlace());
max_overlap_with_gt->Resize({init_num});
max_overlap_with_gt->mutable_data<T>(context.GetPlace());
std::random_device rnd;
std::minstd_rand engine;
......@@ -486,25 +543,36 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
Tensor gt_boxes_slice =
gt_boxes->Slice(gt_boxes_lod[i], gt_boxes_lod[i + 1]);
Tensor im_info_slice = im_info->Slice(i, i + 1);
Tensor max_overlap_slice;
if (is_cascade_rcnn) {
auto* max_overlap = context.Input<Tensor>("MaxOverlap");
max_overlap_slice =
max_overlap->Slice(rpn_rois_lod[i], rpn_rois_lod[i + 1]);
} else {
max_overlap_slice.mutable_data<T>({rpn_rois_slice.dims()[0]},
context.GetPlace());
}
std::vector<Tensor> tensor_output = SampleRoisForOneImage<T>(
dev_ctx, rpn_rois_slice, gt_classes_slice, is_crowd_slice,
gt_boxes_slice, im_info_slice, batch_size_per_im, fg_fraction,
fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums,
engine, use_random, is_cascade_rcnn, is_cls_agnostic);
engine, use_random, is_cascade_rcnn, is_cls_agnostic,
max_overlap_slice);
Tensor sampled_rois = tensor_output[0];
Tensor sampled_labels_int32 = tensor_output[1];
Tensor sampled_bbox_targets = tensor_output[2];
Tensor sampled_bbox_inside_weights = tensor_output[3];
Tensor sampled_bbox_outside_weights = tensor_output[4];
Tensor sampled_max_overlap = tensor_output[5];
AppendRois<T>(rois, kBoxDim * num_rois, &sampled_rois);
AppendRois<int>(labels_int32, num_rois, &sampled_labels_int32);
AppendRois<T>(bbox_targets, kBoxDim * num_rois * class_nums,
&sampled_bbox_targets);
AppendRois<T>(bbox_inside_weights, kBoxDim * num_rois * class_nums,
&sampled_bbox_inside_weights);
AppendRois<T>(bbox_outside_weights, kBoxDim * num_rois * class_nums,
int64_t offset = kBoxDim * num_rois * class_nums;
AppendRois<T>(bbox_targets, offset, &sampled_bbox_targets);
AppendRois<T>(bbox_inside_weights, offset, &sampled_bbox_inside_weights);
AppendRois<T>(bbox_outside_weights, offset,
&sampled_bbox_outside_weights);
AppendRois<T>(max_overlap_with_gt, num_rois, &sampled_max_overlap);
num_rois += sampled_rois.dims()[0];
lod0.emplace_back(num_rois);
......@@ -521,6 +589,8 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
bbox_targets->Resize({num_rois, kBoxDim * class_nums});
bbox_inside_weights->Resize({num_rois, kBoxDim * class_nums});
bbox_outside_weights->Resize({num_rois, kBoxDim * class_nums});
max_overlap_with_gt->Resize({num_rois});
max_overlap_with_gt->set_lod(lod);
}
};
......@@ -550,6 +620,12 @@ class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor), This input is a 2D Tensor with shape [B, 3]. "
"B is the number of input images, "
"each element consists of im_height, im_width, im_scale.");
AddInput("MaxOverlap",
"(LoDTensor), This input is a 1D LoDTensor with shape [N]."
"N is the number of Input(RpnRois), "
"each element is the maximum overlap between "
"the proposal RoI and ground-truth.")
.AsDispensable();
AddOutput(
"Rois",
......@@ -573,6 +649,12 @@ class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor), This output is a 2D LoDTensor with shape [P, 4 * "
"class_nums], "
"each element indicates whether a box should contribute to loss.");
AddOutput("MaxOverlapWithGT",
"(LoDTensor), This output is a 1D LoDTensor with shape [P], "
"each element indicates the maxoverlap "
"between output RoIs and ground-truth. "
"The output RoIs may include ground-truth "
"and the output maxoverlap may contain 1.");
AddAttr<int>("batch_size_per_im", "Batch size of rois per images.");
AddAttr<float>("fg_fraction",
......
......@@ -2601,7 +2601,9 @@ def generate_proposal_labels(rpn_rois,
class_nums=None,
use_random=True,
is_cls_agnostic=False,
is_cascade_rcnn=False):
is_cascade_rcnn=False,
max_overlap=None,
return_max_overlap=False):
"""
**Generate Proposal Labels of Faster-RCNN**
......@@ -2638,25 +2640,29 @@ def generate_proposal_labels(rpn_rois,
use_random(bool): Use random sampling to choose foreground and background boxes.
is_cls_agnostic(bool): bbox regression use class agnostic simply which only represent fg and bg boxes.
is_cascade_rcnn(bool): it will filter some bbox crossing the image's boundary when setting True.
max_overlap(Variable): Maximum overlap between each proposal box and ground-truth.
return_max_overlap(bool): Whether return the maximum overlap between each sampled RoI and ground-truth.
Returns:
tuple:
A tuple with format``(rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights)``.
A tuple with format``(rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights, max_overlap)``.
- **rois**: 2-D LoDTensor with shape ``[batch_size_per_im * batch_size, 4]``. The data type is the same as ``rpn_rois``.
- **labels_int32**: 2-D LoDTensor with shape ``[batch_size_per_im * batch_size, 1]``. The data type must be int32.
- **bbox_targets**: 2-D LoDTensor with shape ``[batch_size_per_im * batch_size, 4 * class_num]``. The regression targets of all RoIs. The data type is the same as ``rpn_rois``.
- **bbox_inside_weights**: 2-D LoDTensor with shape ``[batch_size_per_im * batch_size, 4 * class_num]``. The weights of foreground boxes' regression loss. The data type is the same as ``rpn_rois``.
- **bbox_outside_weights**: 2-D LoDTensor with shape ``[batch_size_per_im * batch_size, 4 * class_num]``. The weights of regression loss. The data type is the same as ``rpn_rois``.
- **max_overlap**: 1-D LoDTensor with shape ``[P]``. P is the number of output ``rois``. The maximum overlap between each sampled RoI and ground-truth.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
paddle.enable_static()
rpn_rois = fluid.data(name='rpn_rois', shape=[None, 4], dtype='float32')
gt_classes = fluid.data(name='gt_classes', shape=[None, 1], dtype='float32')
is_crowd = fluid.data(name='is_crowd', shape=[None, 1], dtype='float32')
gt_classes = fluid.data(name='gt_classes', shape=[None, 1], dtype='int32')
is_crowd = fluid.data(name='is_crowd', shape=[None, 1], dtype='int32')
gt_boxes = fluid.data(name='gt_boxes', shape=[None, 4], dtype='float32')
im_info = fluid.data(name='im_info', shape=[None, 3], dtype='float32')
rois, labels, bbox, inside_weights, outside_weights = fluid.layers.generate_proposal_labels(
......@@ -2673,6 +2679,8 @@ def generate_proposal_labels(rpn_rois,
'generate_proposal_labels')
check_variable_and_dtype(is_crowd, 'is_crowd', ['int32'],
'generate_proposal_labels')
if is_cascade_rcnn:
assert max_overlap is not None, "Input max_overlap of generate_proposal_labels should not be None if is_cascade_rcnn is True"
rois = helper.create_variable_for_type_inference(dtype=rpn_rois.dtype)
labels_int32 = helper.create_variable_for_type_inference(
......@@ -2683,22 +2691,28 @@ def generate_proposal_labels(rpn_rois,
dtype=rpn_rois.dtype)
bbox_outside_weights = helper.create_variable_for_type_inference(
dtype=rpn_rois.dtype)
max_overlap_with_gt = helper.create_variable_for_type_inference(
dtype=rpn_rois.dtype)
inputs = {
'RpnRois': rpn_rois,
'GtClasses': gt_classes,
'IsCrowd': is_crowd,
'GtBoxes': gt_boxes,
'ImInfo': im_info,
}
if max_overlap is not None:
inputs['MaxOverlap'] = max_overlap
helper.append_op(
type="generate_proposal_labels",
inputs={
'RpnRois': rpn_rois,
'GtClasses': gt_classes,
'IsCrowd': is_crowd,
'GtBoxes': gt_boxes,
'ImInfo': im_info
},
inputs=inputs,
outputs={
'Rois': rois,
'LabelsInt32': labels_int32,
'BboxTargets': bbox_targets,
'BboxInsideWeights': bbox_inside_weights,
'BboxOutsideWeights': bbox_outside_weights
'BboxOutsideWeights': bbox_outside_weights,
'MaxOverlapWithGT': max_overlap_with_gt
},
attrs={
'batch_size_per_im': batch_size_per_im,
......@@ -2718,7 +2732,10 @@ def generate_proposal_labels(rpn_rois,
bbox_targets.stop_gradient = True
bbox_inside_weights.stop_gradient = True
bbox_outside_weights.stop_gradient = True
max_overlap_with_gt.stop_gradient = True
if return_max_overlap:
return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights, max_overlap_with_gt
return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights
......
......@@ -289,40 +289,39 @@ class TestAnchorGenerator(unittest.TestCase):
class TestGenerateProposalLabels(unittest.TestCase):
def check_out(self, outs):
rois = outs[0]
labels_int32 = outs[1]
bbox_targets = outs[2]
bbox_inside_weights = outs[3]
bbox_outside_weights = outs[4]
assert rois.shape[1] == 4
assert rois.shape[0] == labels_int32.shape[0]
assert rois.shape[0] == bbox_targets.shape[0]
assert rois.shape[0] == bbox_inside_weights.shape[0]
assert rois.shape[0] == bbox_outside_weights.shape[0]
assert bbox_targets.shape[1] == 4 * self.class_nums
assert bbox_inside_weights.shape[1] == 4 * self.class_nums
assert bbox_outside_weights.shape[1] == 4 * self.class_nums
if len(outs) == 6:
max_overlap_with_gt = outs[5]
assert max_overlap_with_gt.shape[0] == rois.shape[0]
def test_generate_proposal_labels(self):
program = Program()
with program_guard(program):
rpn_rois = layers.data(
name='rpn_rois',
shape=[4, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
gt_classes = layers.data(
name='gt_classes',
shape=[6],
dtype='int32',
lod_level=1,
append_batch_size=False)
is_crowd = layers.data(
name='is_crowd',
shape=[6],
dtype='int32',
lod_level=1,
append_batch_size=False)
gt_boxes = layers.data(
name='gt_boxes',
shape=[6, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
im_info = layers.data(
name='im_info',
shape=[1, 3],
dtype='float32',
lod_level=1,
append_batch_size=False)
class_nums = 5
rpn_rois = fluid.data(
name='rpn_rois', shape=[4, 4], dtype='float32', lod_level=1)
gt_classes = fluid.data(
name='gt_classes', shape=[6], dtype='int32', lod_level=1)
is_crowd = fluid.data(
name='is_crowd', shape=[6], dtype='int32', lod_level=1)
gt_boxes = fluid.data(
name='gt_boxes', shape=[6, 4], dtype='float32', lod_level=1)
im_info = fluid.data(name='im_info', shape=[1, 3], dtype='float32')
max_overlap = fluid.data(
name='max_overlap', shape=[4], dtype='float32', lod_level=1)
self.class_nums = 5
outs = fluid.layers.generate_proposal_labels(
rpn_rois=rpn_rois,
gt_classes=gt_classes,
......@@ -335,20 +334,27 @@ class TestGenerateProposalLabels(unittest.TestCase):
bg_thresh_hi=0.5,
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=class_nums)
class_nums=self.class_nums)
outs_1 = fluid.layers.generate_proposal_labels(
rpn_rois=rpn_rois,
gt_classes=gt_classes,
is_crowd=is_crowd,
gt_boxes=gt_boxes,
im_info=im_info,
batch_size_per_im=2,
fg_fraction=0.5,
fg_thresh=0.5,
bg_thresh_hi=0.5,
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=self.class_nums,
is_cascade_rcnn=True,
max_overlap=max_overlap,
return_max_overlap=True)
self.check_out(outs)
self.check_out(outs_1)
rois = outs[0]
labels_int32 = outs[1]
bbox_targets = outs[2]
bbox_inside_weights = outs[3]
bbox_outside_weights = outs[4]
assert rois.shape[1] == 4
assert rois.shape[0] == labels_int32.shape[0]
assert rois.shape[0] == bbox_targets.shape[0]
assert rois.shape[0] == bbox_inside_weights.shape[0]
assert rois.shape[0] == bbox_outside_weights.shape[0]
assert bbox_targets.shape[1] == 4 * class_nums
assert bbox_inside_weights.shape[1] == 4 * class_nums
assert bbox_outside_weights.shape[1] == 4 * class_nums
class TestGenerateMaskLabels(unittest.TestCase):
......
......@@ -22,66 +22,91 @@ import paddle.fluid as fluid
from op_test import OpTest
def generate_proposal_labels_in_python(
rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, batch_size_per_im,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
class_nums, use_random, is_cls_agnostic, is_cascade_rcnn):
def generate_proposal_labels_in_python(rpn_rois,
gt_classes,
is_crowd,
gt_boxes,
im_info,
batch_size_per_im,
fg_fraction,
fg_thresh,
bg_thresh_hi,
bg_thresh_lo,
bbox_reg_weights,
class_nums,
use_random,
is_cls_agnostic,
is_cascade_rcnn,
max_overlaps=None):
rois = []
labels_int32 = []
bbox_targets = []
bbox_inside_weights = []
bbox_outside_weights = []
max_overlap_with_gt = []
lod = []
assert len(rpn_rois) == len(
im_info), 'batch size of rpn_rois and ground_truth is not matched'
for im_i in range(len(im_info)):
max_overlap = max_overlaps[im_i] if is_cascade_rcnn else None
frcn_blobs = _sample_rois(
rpn_rois[im_i], gt_classes[im_i], is_crowd[im_i], gt_boxes[im_i],
im_info[im_i], batch_size_per_im, fg_fraction, fg_thresh,
bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums,
use_random, is_cls_agnostic, is_cascade_rcnn)
use_random, is_cls_agnostic, is_cascade_rcnn, max_overlap)
lod.append(frcn_blobs['rois'].shape[0])
rois.append(frcn_blobs['rois'])
labels_int32.append(frcn_blobs['labels_int32'])
bbox_targets.append(frcn_blobs['bbox_targets'])
bbox_inside_weights.append(frcn_blobs['bbox_inside_weights'])
bbox_outside_weights.append(frcn_blobs['bbox_outside_weights'])
max_overlap_with_gt.append(frcn_blobs['max_overlap'])
return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights, lod
return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights, max_overlap_with_gt, lod
def filter_roi(rois, max_overlap):
ws = rois[:, 2] - rois[:, 0] + 1
hs = rois[:, 3] - rois[:, 1] + 1
keep = np.where((ws > 0) & (hs > 0) & (max_overlap < 1.0))[0]
if len(keep) > 0:
return rois[keep, :]
return np.zeros((1, 4)).astype('float32')
def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
bg_thresh_lo, bbox_reg_weights, class_nums, use_random,
is_cls_agnostic, is_cascade_rcnn):
is_cls_agnostic, is_cascade_rcnn, max_overlap):
rois_per_image = int(batch_size_per_im)
fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
# Roidb
im_scale = im_info[2]
inv_im_scale = 1. / im_scale
if is_cascade_rcnn:
rpn_rois = rpn_rois[len(gt_boxes):, :]
rpn_rois = rpn_rois * inv_im_scale
if is_cascade_rcnn:
rpn_rois = filter_roi(rpn_rois, max_overlap)
boxes = np.vstack([gt_boxes, rpn_rois])
gt_overlaps = np.zeros((boxes.shape[0], class_nums))
box_to_gt_ind_map = np.zeros((boxes.shape[0]), dtype=np.int32)
if len(gt_boxes) > 0:
proposal_to_gt_overlaps = _bbox_overlaps(boxes, gt_boxes)
overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
overlaps_max = proposal_to_gt_overlaps.max(axis=1)
# Boxes which with non-zero overlap with gt boxes
overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
overlapped_boxes_ind]]
gt_overlaps[overlapped_boxes_ind,
overlapped_boxes_gt_classes] = overlaps_max[
overlapped_boxes_ind]
box_to_gt_ind_map[overlapped_boxes_ind] = overlaps_argmax[
overlapped_boxes_ind]
proposal_to_gt_overlaps = _bbox_overlaps(boxes, gt_boxes)
overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
overlaps_max = proposal_to_gt_overlaps.max(axis=1)
# Boxes which with non-zero overlap with gt boxes
overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
overlapped_boxes_ind]]
gt_overlaps[overlapped_boxes_ind,
overlapped_boxes_gt_classes] = overlaps_max[
overlapped_boxes_ind]
box_to_gt_ind_map[overlapped_boxes_ind] = overlaps_argmax[
overlapped_boxes_ind]
crowd_ind = np.where(is_crowd)[0]
gt_overlaps[crowd_ind] = -1.0
......@@ -90,11 +115,6 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
if is_cascade_rcnn:
# Cascade RCNN Decode Filter
ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1
keep = np.where((ws > 0) & (hs > 0))[0]
boxes = boxes[keep]
max_overlaps = max_overlaps[keep]
fg_inds = np.where(max_overlaps >= fg_thresh)[0]
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
bg_thresh_lo))[0]
......@@ -125,6 +145,7 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
sampled_labels = max_classes[keep_inds]
sampled_labels[fg_rois_per_this_image:] = 0
sampled_boxes = boxes[keep_inds]
sampled_max_overlap = max_overlaps[keep_inds]
sampled_gts = gt_boxes[box_to_gt_ind_map[keep_inds]]
sampled_gts[fg_rois_per_this_image:, :] = gt_boxes[0]
bbox_label_targets = _compute_targets(sampled_boxes, sampled_gts,
......@@ -142,7 +163,8 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
labels_int32=sampled_labels,
bbox_targets=bbox_targets,
bbox_inside_weights=bbox_inside_weights,
bbox_outside_weights=bbox_outside_weights)
bbox_outside_weights=bbox_outside_weights,
max_overlap=sampled_max_overlap)
return frcn_blobs
......@@ -226,9 +248,9 @@ class TestGenerateProposalLabelsOp(OpTest):
def set_data(self):
#self.use_random = False
self.init_use_random()
self.init_test_cascade()
self.init_test_params()
self.init_test_input()
self.init_test_cascade()
self.init_test_output()
self.inputs = {
......@@ -236,8 +258,12 @@ class TestGenerateProposalLabelsOp(OpTest):
'GtClasses': (self.gt_classes[0], self.gts_lod),
'IsCrowd': (self.is_crowd[0], self.gts_lod),
'GtBoxes': (self.gt_boxes[0], self.gts_lod),
'ImInfo': self.im_info
'ImInfo': self.im_info,
}
if self.max_overlaps is not None:
self.inputs['MaxOverlap'] = (self.max_overlaps[0],
self.rpn_rois_lod)
self.attrs = {
'batch_size_per_im': self.batch_size_per_im,
'fg_fraction': self.fg_fraction,
......@@ -256,6 +282,7 @@ class TestGenerateProposalLabelsOp(OpTest):
'BboxTargets': (self.bbox_targets, [self.lod]),
'BboxInsideWeights': (self.bbox_inside_weights, [self.lod]),
'BboxOutsideWeights': (self.bbox_outside_weights, [self.lod]),
'MaxOverlapWithGT': (self.max_overlap_with_gt, [self.lod]),
}
def test_check_output(self):
......@@ -267,12 +294,13 @@ class TestGenerateProposalLabelsOp(OpTest):
def init_test_cascade(self, ):
self.is_cascade_rcnn = False
self.max_overlaps = None
def init_use_random(self):
self.use_random = False
def init_test_params(self):
self.batch_size_per_im = 512
self.batch_size_per_im = 100
self.fg_fraction = 0.25
self.fg_thresh = 0.5
self.bg_thresh_hi = 0.5
......@@ -284,7 +312,7 @@ class TestGenerateProposalLabelsOp(OpTest):
def init_test_input(self):
np.random.seed(0)
gt_nums = 6 # Keep same with batch_size_per_im for unittest
proposal_nums = 2000 if not self.is_cascade_rcnn else 512 #self.batch_size_per_im - gt_nums
proposal_nums = 200
images_shape = [[64, 64]]
self.im_info = np.ones((len(images_shape), 3)).astype(np.float32)
for i in range(len(images_shape)):
......@@ -301,24 +329,16 @@ class TestGenerateProposalLabelsOp(OpTest):
self.gt_boxes = [gt['boxes'] for gt in ground_truth]
self.is_crowd = [gt['is_crowd'] for gt in ground_truth]
if self.is_cascade_rcnn:
rpn_rois_new = []
for im_i in range(len(self.im_info)):
gt_boxes = self.gt_boxes[im_i]
rpn_rois = np.vstack(
[gt_boxes, self.rpn_rois[im_i][len(gt_boxes):, :]])
rpn_rois_new.append(rpn_rois)
self.rpn_rois = rpn_rois_new
def init_test_output(self):
self.rois, self.labels_int32, self.bbox_targets, \
self.bbox_inside_weights, self.bbox_outside_weights, \
self.max_overlap_with_gt, \
self.lod = generate_proposal_labels_in_python(
self.rpn_rois, self.gt_classes, self.is_crowd, self.gt_boxes, self.im_info,
self.batch_size_per_im, self.fg_fraction,
self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo,
self.bbox_reg_weights, self.class_nums, self.use_random,
self.is_cls_agnostic, self.is_cascade_rcnn
self.is_cls_agnostic, self.is_cascade_rcnn, self.max_overlaps
)
self.rois = np.vstack(self.rois)
self.labels_int32 = np.hstack(self.labels_int32)
......@@ -326,11 +346,18 @@ class TestGenerateProposalLabelsOp(OpTest):
self.bbox_targets = np.vstack(self.bbox_targets)
self.bbox_inside_weights = np.vstack(self.bbox_inside_weights)
self.bbox_outside_weights = np.vstack(self.bbox_outside_weights)
self.max_overlap_with_gt = np.vstack(self.max_overlap_with_gt)
class TestCascade(TestGenerateProposalLabelsOp):
def init_test_cascade(self):
self.is_cascade_rcnn = True
roi_num = len(self.rpn_rois[0])
self.max_overlaps = []
max_overlap = np.random.rand(roi_num).astype('float32')
# Make GT samples with overlap = 1
max_overlap[max_overlap > 0.9] = 1.
self.max_overlaps.append(max_overlap)
class TestUseRandom(TestGenerateProposalLabelsOp):
......@@ -389,6 +416,15 @@ class TestOnlyGT(TestCascade):
self.rpn_rois_lod = self.gts_lod
class TestOnlyGT2(TestCascade):
def init_test_cascade(self):
self.is_cascade_rcnn = True
roi_num = len(self.rpn_rois[0])
self.max_overlaps = []
max_overlap = np.ones(roi_num).astype('float32')
self.max_overlaps.append(max_overlap)
def _generate_proposals(images_shape, proposal_nums):
rpn_rois = []
rpn_rois_lod = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册