From 9557cc218d3d71947d7204dce6d711126eb80ad0 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Tue, 4 Sep 2018 11:54:38 +0800 Subject: [PATCH] Refine and fix some code for faster-rcnn. (#13135) * Fix bug in generate_proposals_op. * Fix data type for RoIs. * Refine and fix rpn_target_assign_op. * Add the missing file bbox_util.h * Rename BoxEncoder to BoxToDelta --- paddle/fluid/operators/detection/bbox_util.h | 66 ++++ .../detection/generate_proposal_labels_op.cc | 39 +-- .../detection/generate_proposals_op.cc | 5 +- .../detection/rpn_target_assign_op.cc | 291 +++++++++++------- paddle/fluid/operators/roi_pool_op.cu | 12 +- paddle/fluid/operators/roi_pool_op.h | 4 +- python/paddle/fluid/layers/detection.py | 39 ++- python/paddle/fluid/tests/test_detection.py | 18 +- .../test_generate_proposal_labels.py | 4 +- .../fluid/tests/unittests/test_roi_pool_op.py | 4 +- .../unittests/test_rpn_target_assign_op.py | 131 +++++--- 11 files changed, 388 insertions(+), 225 deletions(-) create mode 100644 paddle/fluid/operators/detection/bbox_util.h diff --git a/paddle/fluid/operators/detection/bbox_util.h b/paddle/fluid/operators/detection/bbox_util.h new file mode 100644 index 00000000000..0dee1781623 --- /dev/null +++ b/paddle/fluid/operators/detection/bbox_util.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { + +/* + * transform that computes target bounding-box regression deltas + * given proposal boxes and ground-truth boxes. + */ +template +inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes, + const framework::Tensor& gt_boxes, const T* weights, + const bool normalized, framework::Tensor* box_delta) { + auto ex_boxes_et = framework::EigenTensor::From(ex_boxes); + auto gt_boxes_et = framework::EigenTensor::From(gt_boxes); + auto trg = framework::EigenTensor::From(*box_delta); + T ex_w, ex_h, ex_ctr_x, ex_ctr_y, gt_w, gt_h, gt_ctr_x, gt_ctr_y; + for (int64_t i = 0; i < box_num; ++i) { + ex_w = ex_boxes_et(i, 2) - ex_boxes_et(i, 0) + (normalized == false); + ex_h = ex_boxes_et(i, 3) - ex_boxes_et(i, 1) + (normalized == false); + ex_ctr_x = ex_boxes_et(i, 0) + 0.5 * ex_w; + ex_ctr_y = ex_boxes_et(i, 1) + 0.5 * ex_h; + + gt_w = gt_boxes_et(i, 2) - gt_boxes_et(i, 0) + (normalized == false); + gt_h = gt_boxes_et(i, 3) - gt_boxes_et(i, 1) + (normalized == false); + gt_ctr_x = gt_boxes_et(i, 0) + 0.5 * gt_w; + gt_ctr_y = gt_boxes_et(i, 1) + 0.5 * gt_h; + + trg(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w; + trg(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h; + trg(i, 2) = std::log(gt_w / ex_w); + trg(i, 3) = std::log(gt_h / ex_h); + + if (weights) { + trg(i, 0) = trg(i, 0) / weights[0]; + trg(i, 1) = trg(i, 1) / weights[1]; + trg(i, 2) = trg(i, 2) / weights[2]; + trg(i, 3) = trg(i, 3) / weights[3]; + } + } +} + +template +void Gather(const T* in, const int in_stride, const int* index, const int num, + T* out) { + const int stride_bytes = in_stride * sizeof(T); + for (int i = 0; i < num; ++i) { + int id = index[i]; + memcpy(out + i * in_stride, in + id * in_stride, stride_bytes); + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index 0571c46f6be..be06dc19743 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/math_function.h" @@ -133,31 +134,6 @@ void BboxOverlaps(const Tensor& r_boxes, const Tensor& c_boxes, } } -template -void BoxToDelta(int box_num, const Tensor& ex_boxes, const Tensor& gt_boxes, - const std::vector& weights, Tensor* box_delta) { - auto ex_boxes_et = framework::EigenTensor::From(ex_boxes); - auto gt_boxes_et = framework::EigenTensor::From(gt_boxes); - auto box_delta_et = framework::EigenTensor::From(*box_delta); - T ex_w, ex_h, ex_ctr_x, ex_ctr_y, gt_w, gt_h, gt_ctr_x, gt_ctr_y; - for (int64_t i = 0; i < box_num; ++i) { - ex_w = ex_boxes_et(i, 2) - ex_boxes_et(i, 0) + 1; - ex_h = ex_boxes_et(i, 3) - ex_boxes_et(i, 1) + 1; - ex_ctr_x = ex_boxes_et(i, 0) + 0.5 * ex_w; - ex_ctr_y = ex_boxes_et(i, 1) + 0.5 * ex_h; - - gt_w = gt_boxes_et(i, 2) - gt_boxes_et(i, 0) + 1; - gt_h = gt_boxes_et(i, 3) - gt_boxes_et(i, 1) + 1; - gt_ctr_x = gt_boxes_et(i, 0) + 0.5 * gt_w; - gt_ctr_y = gt_boxes_et(i, 1) + 0.5 * gt_h; - - box_delta_et(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0]; - box_delta_et(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1]; - box_delta_et(i, 2) = log(gt_w / ex_w) / ex_w / weights[2]; - box_delta_et(i, 3) = log(gt_h / ex_h) / ex_h / weights[3]; - } -} - template std::vector> SampleFgBgGt( const platform::CPUDeviceContext& context, Tensor* iou, @@ -243,12 +219,11 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context, Tensor* sampled_labels, Tensor* sampled_gts) { int fg_num = fg_inds.size(); int bg_num = bg_inds.size(); - int gt_num = fg_num + bg_num; Tensor fg_inds_t, bg_inds_t, gt_box_inds_t, gt_label_inds_t; int* fg_inds_data = fg_inds_t.mutable_data({fg_num}, context.GetPlace()); int* bg_inds_data = bg_inds_t.mutable_data({bg_num}, context.GetPlace()); int* gt_box_inds_data = - gt_box_inds_t.mutable_data({gt_num}, context.GetPlace()); + gt_box_inds_t.mutable_data({fg_num}, context.GetPlace()); int* gt_label_inds_data = gt_label_inds_t.mutable_data({fg_num}, context.GetPlace()); std::copy(fg_inds.begin(), fg_inds.end(), fg_inds_data); @@ -303,18 +278,20 @@ std::vector SampleRoisForOneImage( // Gather boxes and labels Tensor sampled_boxes, sampled_labels, sampled_gts; - int boxes_num = fg_inds.size() + bg_inds.size(); + int fg_num = fg_inds.size(); + int bg_num = bg_inds.size(); + int boxes_num = fg_num + bg_num; framework::DDim bbox_dim({boxes_num, kBoxDim}); sampled_boxes.mutable_data(bbox_dim, context.GetPlace()); sampled_labels.mutable_data({boxes_num}, context.GetPlace()); - sampled_gts.mutable_data(bbox_dim, context.GetPlace()); + sampled_gts.mutable_data({fg_num, kBoxDim}, context.GetPlace()); GatherBoxesLabels(context, boxes, *gt_boxes, *gt_classes, fg_inds, bg_inds, gt_inds, &sampled_boxes, &sampled_labels, &sampled_gts); // Compute targets Tensor bbox_targets_single; bbox_targets_single.mutable_data(bbox_dim, context.GetPlace()); - BoxToDelta(boxes_num, sampled_boxes, sampled_gts, bbox_reg_weights, + BoxToDelta(fg_num, sampled_boxes, sampled_gts, nullptr, false, &bbox_targets_single); // Scale rois @@ -427,7 +404,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel { auto rpn_rois_lod = rpn_rois->lod().back(); auto gt_classes_lod = gt_classes->lod().back(); auto gt_boxes_lod = gt_boxes->lod().back(); - for (size_t i = 0; i < n; ++i) { + for (int i = 0; i < n; ++i) { Tensor rpn_rois_slice = rpn_rois->Slice(rpn_rois_lod[i], rpn_rois_lod[i + 1]); Tensor gt_classes_slice = diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index fcdcafae727..ebe6830eccd 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -311,8 +311,7 @@ class GenerateProposalsKernel : public framework::OpKernel { rpn_rois->mutable_data({bbox_deltas->numel() / 4, 4}, context.GetPlace()); - rpn_roi_probs->mutable_data({scores->numel() / 4, 1}, - context.GetPlace()); + rpn_roi_probs->mutable_data({scores->numel(), 1}, context.GetPlace()); Tensor bbox_deltas_swap, scores_swap; bbox_deltas_swap.mutable_data({num, h_bbox, w_bbox, c_bbox}, @@ -421,7 +420,7 @@ class GenerateProposalsKernel : public framework::OpKernel { CPUGather(ctx, proposals, keep, &bbox_sel); CPUGather(ctx, scores_sel, keep, &scores_filter); if (nms_thresh <= 0) { - return std::make_pair(bbox_sel, scores_sel); + return std::make_pair(bbox_sel, scores_filter); } Tensor keep_nms = NMS(ctx, &bbox_sel, &scores_filter, nms_thresh, eta); diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index 177ff7cf187..88757f25cd9 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/math/math_function.h" namespace paddle { @@ -46,156 +47,219 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { auto in_dims = ctx->GetInputDim("DistMat"); PADDLE_ENFORCE_EQ(in_dims.size(), 2, "The rank of Input(DistMat) must be 2."); + + ctx->SetOutputDim("LocationIndex", {-1}); + ctx->SetOutputDim("ScoreIndex", {-1}); + ctx->SetOutputDim("TargetLabel", {-1, 1}); + ctx->SetOutputDim("TargetBBox", {-1, 4}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType( + ctx.Input("DistMat")->type()), + platform::CPUPlace()); } }; template class RpnTargetAssignKernel : public framework::OpKernel { public: + void Compute(const framework::ExecutionContext& context) const override { + auto* anchor_t = context.Input("Anchor"); // (H*W*A) * 4 + auto* gt_bbox_t = context.Input("GtBox"); + auto* dist_t = context.Input("DistMat"); + + auto* loc_index_t = context.Output("LocationIndex"); + auto* score_index_t = context.Output("ScoreIndex"); + auto* tgt_bbox_t = context.Output("TargetBBox"); + auto* tgt_lbl_t = context.Output("TargetLabel"); + + auto lod = dist_t->lod().back(); + int64_t batch_num = static_cast(lod.size() - 1); + int64_t anchor_num = dist_t->dims()[1]; + PADDLE_ENFORCE_EQ(anchor_num, anchor_t->dims()[0]); + + int rpn_batch_size = context.Attr("rpn_batch_size_per_im"); + float pos_threshold = context.Attr("rpn_positive_overlap"); + float neg_threshold = context.Attr("rpn_negative_overlap"); + float fg_fraction = context.Attr("fg_fraction"); + + int fg_num_per_batch = static_cast(rpn_batch_size * fg_fraction); + + int64_t max_num = batch_num * anchor_num; + auto place = context.GetPlace(); + + tgt_bbox_t->mutable_data({max_num, 4}, place); + auto* loc_index = loc_index_t->mutable_data({max_num}, place); + auto* score_index = score_index_t->mutable_data({max_num}, place); + + Tensor tmp_tgt_lbl; + auto* tmp_lbl_data = tmp_tgt_lbl.mutable_data({max_num}, place); + auto& dev_ctx = context.device_context(); + math::SetConstant iset; + iset(dev_ctx, &tmp_tgt_lbl, static_cast(-1)); + + std::random_device rnd; + std::minstd_rand engine; + int seed = + context.Attr("fix_seed") ? context.Attr("seed") : rnd(); + engine.seed(seed); + + int fg_num = 0; + int bg_num = 0; + for (int i = 0; i < batch_num; ++i) { + Tensor dist = dist_t->Slice(lod[i], lod[i + 1]); + Tensor gt_bbox = gt_bbox_t->Slice(lod[i], lod[i + 1]); + auto fg_bg_gt = SampleFgBgGt(dev_ctx, dist, pos_threshold, neg_threshold, + rpn_batch_size, fg_num_per_batch, engine, + tmp_lbl_data + i * anchor_num); + + int cur_fg_num = fg_bg_gt[0].size(); + int cur_bg_num = fg_bg_gt[1].size(); + std::transform(fg_bg_gt[0].begin(), fg_bg_gt[0].end(), loc_index, + [i, anchor_num](int d) { return d + i * anchor_num; }); + memcpy(score_index, loc_index, cur_fg_num * sizeof(int)); + std::transform(fg_bg_gt[1].begin(), fg_bg_gt[1].end(), + score_index + cur_fg_num, + [i, anchor_num](int d) { return d + i * anchor_num; }); + + // get target bbox deltas + if (cur_fg_num) { + Tensor fg_gt; + T* gt_data = fg_gt.mutable_data({cur_fg_num, 4}, place); + Tensor tgt_bbox = tgt_bbox_t->Slice(fg_num, fg_num + cur_fg_num); + T* tgt_data = tgt_bbox.data(); + Gather(anchor_t->data(), 4, + reinterpret_cast(&fg_bg_gt[0][0]), cur_fg_num, + tgt_data); + Gather(gt_bbox.data(), 4, reinterpret_cast(&fg_bg_gt[2][0]), + cur_fg_num, gt_data); + BoxToDelta(cur_fg_num, tgt_bbox, fg_gt, nullptr, false, &tgt_bbox); + } + + loc_index += cur_fg_num; + score_index += cur_fg_num + cur_bg_num; + fg_num += cur_fg_num; + bg_num += cur_bg_num; + } + + int lbl_num = fg_num + bg_num; + PADDLE_ENFORCE_LE(fg_num, max_num); + PADDLE_ENFORCE_LE(lbl_num, max_num); + + tgt_bbox_t->Resize({fg_num, 4}); + loc_index_t->Resize({fg_num}); + score_index_t->Resize({lbl_num}); + auto* lbl_data = tgt_lbl_t->mutable_data({lbl_num, 1}, place); + Gather(tmp_lbl_data, 1, score_index_t->data(), lbl_num, + lbl_data); + } + + private: void ScoreAssign(const T* dist_data, const Tensor& anchor_to_gt_max, const int row, const int col, const float pos_threshold, - const float neg_threshold, int64_t* target_label_data, + const float neg_threshold, int64_t* target_label, std::vector* fg_inds, std::vector* bg_inds) const { - int fg_offset = fg_inds->size(); - int bg_offset = bg_inds->size(); + float epsilon = 0.0001; for (int64_t i = 0; i < row; ++i) { const T* v = dist_data + i * col; - T max_dist = *std::max_element(v, v + col); + T max = *std::max_element(v, v + col); for (int64_t j = 0; j < col; ++j) { - T val = dist_data[i * col + j]; - if (val == max_dist) target_label_data[j] = 1; + if (std::abs(max - v[j]) < epsilon) { + target_label[j] = 1; + } } } - // Pick the fg/bg and count the number + // Pick the fg/bg + const T* anchor_to_gt_max_data = anchor_to_gt_max.data(); for (int64_t j = 0; j < col; ++j) { - if (anchor_to_gt_max.data()[j] > pos_threshold) { - target_label_data[j] = 1; - } else if (anchor_to_gt_max.data()[j] < neg_threshold) { - target_label_data[j] = 0; + if (anchor_to_gt_max_data[j] >= pos_threshold) { + target_label[j] = 1; + } else if (anchor_to_gt_max_data[j] < neg_threshold) { + target_label[j] = 0; } - if (target_label_data[j] == 1) { - fg_inds->push_back(fg_offset + j); - } else if (target_label_data[j] == 0) { - bg_inds->push_back(bg_offset + j); + if (target_label[j] == 1) { + fg_inds->push_back(j); + } else if (target_label[j] == 0) { + bg_inds->push_back(j); } } } - void ReservoirSampling(const int num, const int offset, - std::minstd_rand engine, + void ReservoirSampling(const int num, std::minstd_rand engine, std::vector* inds) const { std::uniform_real_distribution uniform(0, 1); - const int64_t size = static_cast(inds->size() - offset); - if (size > num) { - for (int64_t i = num; i < size; ++i) { + size_t len = inds->size(); + if (len > static_cast(num)) { + for (size_t i = num; i < len; ++i) { int rng_ind = std::floor(uniform(engine) * i); if (rng_ind < num) - std::iter_swap(inds->begin() + rng_ind + offset, - inds->begin() + i + offset); + std::iter_swap(inds->begin() + rng_ind, inds->begin() + i); } + inds->resize(num); } } - void RpnTargetAssign(const framework::ExecutionContext& ctx, - const Tensor& dist, const float pos_threshold, - const float neg_threshold, const int rpn_batch_size, - const int fg_num, std::minstd_rand engine, - std::vector* fg_inds, std::vector* bg_inds, - int64_t* target_label_data) const { + // std::vector> RpnTargetAssign( + std::vector> SampleFgBgGt( + const platform::CPUDeviceContext& ctx, const Tensor& dist, + const float pos_threshold, const float neg_threshold, + const int rpn_batch_size, const int fg_num, std::minstd_rand engine, + int64_t* target_label) const { auto* dist_data = dist.data(); - int64_t row = dist.dims()[0]; - int64_t col = dist.dims()[1]; - int fg_offset = fg_inds->size(); - int bg_offset = bg_inds->size(); + int row = dist.dims()[0]; + int col = dist.dims()[1]; + + std::vector fg_inds; + std::vector bg_inds; + std::vector gt_inds; // Calculate the max IoU between anchors and gt boxes - Tensor anchor_to_gt_max; - anchor_to_gt_max.mutable_data( - framework::make_ddim({static_cast(col), 1}), - platform::CPUPlace()); - auto& place = *ctx.template device_context() - .eigen_device(); - auto x = EigenMatrix::From(dist); - auto x_col_max = EigenMatrix::From(anchor_to_gt_max); - x_col_max.device(place) = - x.maximum(Eigen::DSizes(0)) - .reshape(Eigen::DSizes(static_cast(col), 1)); + // Map from anchor to gt box that has highest overlap + auto place = ctx.GetPlace(); + Tensor anchor_to_gt_max, anchor_to_gt_argmax; + anchor_to_gt_max.mutable_data({col}, place); + int* argmax = anchor_to_gt_argmax.mutable_data({col}, place); + + auto x = framework::EigenMatrix::From(dist); + auto x_col_max = framework::EigenVector::Flatten(anchor_to_gt_max); + auto x_col_argmax = + framework::EigenVector::Flatten(anchor_to_gt_argmax); + x_col_max = x.maximum(Eigen::DSizes(0)); + x_col_argmax = x.argmax(0).template cast(); + // Follow the Faster RCNN's implementation ScoreAssign(dist_data, anchor_to_gt_max, row, col, pos_threshold, - neg_threshold, target_label_data, fg_inds, bg_inds); + neg_threshold, target_label, &fg_inds, &bg_inds); // Reservoir Sampling - ReservoirSampling(fg_num, fg_offset, engine, fg_inds); - int bg_num = rpn_batch_size - (fg_inds->size() - fg_offset); - ReservoirSampling(bg_num, bg_offset, engine, bg_inds); - } + ReservoirSampling(fg_num, engine, &fg_inds); + int fg_num2 = static_cast(fg_inds.size()); + int bg_num = rpn_batch_size - fg_num2; + ReservoirSampling(bg_num, engine, &bg_inds); - void Compute(const framework::ExecutionContext& context) const override { - auto* dist = context.Input("DistMat"); - auto* loc_index = context.Output("LocationIndex"); - auto* score_index = context.Output("ScoreIndex"); - auto* tgt_lbl = context.Output("TargetLabel"); - - auto col = dist->dims()[1]; - int64_t n = dist->lod().size() == 0UL - ? 1 - : static_cast(dist->lod().back().size() - 1); - if (dist->lod().size()) { - PADDLE_ENFORCE_EQ(dist->lod().size(), 1UL, - "Only support 1 level of LoD."); + gt_inds.reserve(fg_num2); + for (int i = 0; i < fg_num2; ++i) { + gt_inds.emplace_back(argmax[fg_inds[i]]); } - int rpn_batch_size = context.Attr("rpn_batch_size_per_im"); - float pos_threshold = context.Attr("rpn_positive_overlap"); - float neg_threshold = context.Attr("rpn_negative_overlap"); - float fg_fraction = context.Attr("fg_fraction"); - - int fg_num = static_cast(rpn_batch_size * fg_fraction); - - int64_t* target_label_data = - tgt_lbl->mutable_data({n * col, 1}, context.GetPlace()); + std::vector> fg_bg_gt; + fg_bg_gt.emplace_back(fg_inds); + fg_bg_gt.emplace_back(bg_inds); + fg_bg_gt.emplace_back(gt_inds); - auto& dev_ctx = context.device_context(); - math::SetConstant iset; - iset(dev_ctx, tgt_lbl, static_cast(-1)); - - std::vector fg_inds; - std::vector bg_inds; - std::random_device rnd; - std::minstd_rand engine; - int seed = - context.Attr("fix_seed") ? context.Attr("seed") : rnd(); - engine.seed(seed); - - if (n == 1) { - RpnTargetAssign(context, *dist, pos_threshold, neg_threshold, - rpn_batch_size, fg_num, engine, &fg_inds, &bg_inds, - target_label_data); - } else { - auto lod = dist->lod().back(); - for (size_t i = 0; i < lod.size() - 1; ++i) { - Tensor one_ins = dist->Slice(lod[i], lod[i + 1]); - RpnTargetAssign(context, one_ins, pos_threshold, neg_threshold, - rpn_batch_size, fg_num, engine, &fg_inds, &bg_inds, - target_label_data + i * col); - } - } - int* loc_index_data = loc_index->mutable_data( - {static_cast(fg_inds.size())}, context.GetPlace()); - int* score_index_data = score_index->mutable_data( - {static_cast(fg_inds.size() + bg_inds.size())}, - context.GetPlace()); - memcpy(loc_index_data, reinterpret_cast(&fg_inds[0]), - fg_inds.size() * sizeof(int)); - memcpy(score_index_data, reinterpret_cast(&fg_inds[0]), - fg_inds.size() * sizeof(int)); - memcpy(score_index_data + fg_inds.size(), - reinterpret_cast(&bg_inds[0]), bg_inds.size() * sizeof(int)); + return fg_bg_gt; } }; class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { + AddInput("Anchor", + "(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4]."); + AddInput("GtBox", "(LoDTensor) input groud-truth bbox with shape [K, 4]."); AddInput( "DistMat", "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape " @@ -241,12 +305,15 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { "ScoreIndex", "(Tensor), The indexes of foreground and background anchors in all " "RPN anchors(The rest anchors are ignored). The shape of the " - "ScoreIndex is [F + B], F and B depend on the value of input " - "tensor and attributes."); - AddOutput("TargetLabel", - "(Tensor), The target labels of each anchor with shape " - "[K * M, 1], " - "K and M is the same as they are in DistMat."); + "ScoreIndex is [F + B], F and B are sampled foreground and backgroud " + " number."); + AddOutput("TargetBBox", + "(Tensor), The target bbox deltas with shape " + "[F, 4], F is the sampled foreground number."); + AddOutput( + "TargetLabel", + "(Tensor), The target labels of each anchor with shape " + "[F + B, 1], F and B are sampled foreground and backgroud number."); AddComment(R"DOC( This operator can be, for given the IoU between the ground truth bboxes and the anchors, to assign classification and regression targets to each prediction. diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu index 50450b62f7b..46e20285db6 100644 --- a/paddle/fluid/operators/roi_pool_op.cu +++ b/paddle/fluid/operators/roi_pool_op.cu @@ -31,7 +31,7 @@ static inline int NumBlocks(const int N) { template __global__ void GPUROIPoolForward( - const int nthreads, const T* input_data, const int64_t* input_rois, + const int nthreads, const T* input_data, const T* input_rois, const float spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, int* roi_batch_id_data, T* output_data, int64_t* argmax_data) { @@ -43,7 +43,7 @@ __global__ void GPUROIPoolForward( int c = (i / pooled_width / pooled_height) % channels; int n = i / pooled_width / pooled_height / channels; - const int64_t* offset_input_rois = input_rois + n * kROISize; + const T* offset_input_rois = input_rois + n * kROISize; int roi_batch_ind = roi_batch_id_data[n]; int roi_start_w = round(offset_input_rois[0] * spatial_scale); int roi_start_h = round(offset_input_rois[1] * spatial_scale); @@ -93,7 +93,7 @@ __global__ void GPUROIPoolForward( template __global__ void GPUROIPoolBackward( - const int nthreads, const int64_t* input_rois, const T* output_grad, + const int nthreads, const T* input_rois, const T* output_grad, const int64_t* argmax_data, const int num_rois, const float spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, int* roi_batch_id_data, @@ -174,8 +174,8 @@ class GPUROIPoolOpKernel : public framework::OpKernel { GPUROIPoolForward< T><<>>( - output_size, in->data(), rois->data(), spatial_scale, - channels, height, width, pooled_height, pooled_width, + output_size, in->data(), rois->data(), spatial_scale, channels, + height, width, pooled_height, pooled_width, roi_batch_id_list_gpu.data(), out->mutable_data(ctx.GetPlace()), argmax->mutable_data(ctx.GetPlace())); } @@ -228,7 +228,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel { if (output_grad_size > 0) { GPUROIPoolBackward< T><<>>( - output_grad_size, rois->data(), out_grad->data(), + output_grad_size, rois->data(), out_grad->data(), argmax->data(), rois_num, spatial_scale, channels, height, width, pooled_height, pooled_width, roi_batch_id_list_gpu.data(), diff --git a/paddle/fluid/operators/roi_pool_op.h b/paddle/fluid/operators/roi_pool_op.h index c4f739b2c6b..07de7c9f0e0 100644 --- a/paddle/fluid/operators/roi_pool_op.h +++ b/paddle/fluid/operators/roi_pool_op.h @@ -72,7 +72,7 @@ class CPUROIPoolOpKernel : public framework::OpKernel { T* output_data = out->mutable_data(ctx.GetPlace()); int64_t* argmax_data = argmax->mutable_data(ctx.GetPlace()); - const int64_t* rois_data = rois->data(); + const T* rois_data = rois->data(); for (int n = 0; n < rois_num; ++n) { int roi_batch_id = roi_batch_id_data[n]; int roi_start_w = round(rois_data[0] * spatial_scale); @@ -171,7 +171,7 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel { } } - const int64_t* rois_data = rois->data(); + const T* rois_data = rois->data(); const T* out_grad_data = out_grad->data(); const int64_t* argmax_data = argmax->data(); T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 5757b2798e4..1bc1dbbecac 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -145,26 +145,23 @@ def rpn_target_assign(loc, """ helper = LayerHelper('rpn_target_assign', **locals()) - # 1. Compute the regression target bboxes - target_bbox = box_coder( - prior_box=anchor_box, - prior_box_var=anchor_var, - target_box=gt_box, - code_type='encode_center_size', - box_normalized=False) - # 2. Compute overlaps between the prior boxes and the gt boxes overlaps + # Compute overlaps between the prior boxes and the gt boxes overlaps iou = iou_similarity(x=gt_box, y=anchor_box) - # 3. Assign target label to anchors - loc_index = helper.create_tmp_variable(dtype=anchor_box.dtype) - score_index = helper.create_tmp_variable(dtype=anchor_box.dtype) - target_label = helper.create_tmp_variable(dtype=anchor_box.dtype) + # Assign target label to anchors + loc_index = helper.create_tmp_variable(dtype='int32') + score_index = helper.create_tmp_variable(dtype='int32') + target_label = helper.create_tmp_variable(dtype='int64') + target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype) helper.append_op( type="rpn_target_assign", - inputs={'DistMat': iou}, + inputs={'Anchor': anchor_box, + 'GtBox': gt_box, + 'DistMat': iou}, outputs={ 'LocationIndex': loc_index, 'ScoreIndex': score_index, - 'TargetLabel': target_label + 'TargetLabel': target_label, + 'TargetBBox': target_bbox, }, attrs={ 'rpn_batch_size_per_im': rpn_batch_size_per_im, @@ -173,16 +170,16 @@ def rpn_target_assign(loc, 'fg_fraction': fg_fraction }) - # 4. Reshape and gather the target entry - scores = nn.reshape(x=scores, shape=(-1, 2)) - loc = nn.reshape(x=loc, shape=(-1, 4)) - target_label = nn.reshape(x=target_label, shape=(-1, 1)) - target_bbox = nn.reshape(x=target_bbox, shape=(-1, 4)) + loc_index.stop_gradient = True + score_index.stop_gradient = True + target_label.stop_gradient = True + target_bbox.stop_gradient = True + scores = nn.reshape(x=scores, shape=(-1, 1)) + loc = nn.reshape(x=loc, shape=(-1, 4)) predicted_scores = nn.gather(scores, score_index) predicted_location = nn.gather(loc, loc_index) - target_label = nn.gather(target_label, score_index) - target_bbox = nn.gather(target_bbox, loc_index) + return predicted_scores, predicted_location, target_label, target_bbox diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index ec0bf3ff8d6..e2564763d19 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -281,7 +281,7 @@ class TestRpnTargetAssign(unittest.TestCase): gt_box = layers.data( name='gt_box', shape=[4], lod_level=1, dtype='float32') - predicted_scores, predicted_location, target_label, target_bbox = layers.rpn_target_assign( + pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign( loc=loc, scores=scores, anchor_box=anchor_box, @@ -292,15 +292,13 @@ class TestRpnTargetAssign(unittest.TestCase): rpn_positive_overlap=0.7, rpn_negative_overlap=0.3) - self.assertIsNotNone(predicted_scores) - self.assertIsNotNone(predicted_location) - self.assertIsNotNone(target_label) - self.assertIsNotNone(target_bbox) - assert predicted_scores.shape[1] == 2 - assert predicted_location.shape[1] == 4 - assert predicted_location.shape[1] == target_bbox.shape[1] - - print(str(program)) + self.assertIsNotNone(pred_scores) + self.assertIsNotNone(pred_loc) + self.assertIsNotNone(tgt_lbl) + self.assertIsNotNone(tgt_bbox) + assert pred_scores.shape[1] == 1 + assert pred_loc.shape[1] == 4 + assert pred_loc.shape[1] == tgt_bbox.shape[1] class TestGenerateProposals(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposal_labels.py b/python/paddle/fluid/tests/unittests/test_generate_proposal_labels.py index ce766fffbce..6dc101b6dad 100644 --- a/python/paddle/fluid/tests/unittests/test_generate_proposal_labels.py +++ b/python/paddle/fluid/tests/unittests/test_generate_proposal_labels.py @@ -177,8 +177,8 @@ def _box_to_delta(ex_boxes, gt_boxes, weights): dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0] dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1] - dw = (np.log(gt_w / ex_w)) / ex_w / weights[2] - dh = (np.log(gt_h / ex_h)) / ex_h / weights[3] + dw = (np.log(gt_w / ex_w)) / weights[2] + dh = (np.log(gt_h / ex_h)) / weights[3] targets = np.vstack([dx, dy, dw, dh]).transpose() return targets diff --git a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py index ed7f467835f..ad4cd2e803b 100644 --- a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py +++ b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py @@ -61,7 +61,7 @@ class TestROIPoolOp(OpTest): for i in range(self.rois_num): roi = self.rois[i] - roi_batch_id = roi[0] + roi_batch_id = int(roi[0]) roi_start_w = int(cpt.round(roi[1] * self.spatial_scale)) roi_start_h = int(cpt.round(roi[2] * self.spatial_scale)) roi_end_w = int(cpt.round(roi[3] * self.spatial_scale)) @@ -125,7 +125,7 @@ class TestROIPoolOp(OpTest): roi = [bno, x1, y1, x2, y2] rois.append(roi) self.rois_num = len(rois) - self.rois = np.array(rois).astype("int64") + self.rois = np.array(rois).astype("float32") def setUp(self): self.op_type = "roi_pool" diff --git a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py index 08c462d9036..bd548009b3a 100644 --- a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py @@ -18,12 +18,17 @@ import unittest import numpy as np import paddle.fluid.core as core from op_test import OpTest +from test_anchor_generator_op import anchor_generator_in_python +from test_generate_proposal_labels import _generate_groundtruth +from test_generate_proposal_labels import _bbox_overlaps, _box_to_delta -def rpn_target_assign(iou, rpn_batch_size_per_im, rpn_positive_overlap, - rpn_negative_overlap, fg_fraction): - iou = np.transpose(iou) +def rpn_target_assign(gt_anchor_iou, rpn_batch_size_per_im, + rpn_positive_overlap, rpn_negative_overlap, fg_fraction): + iou = np.transpose(gt_anchor_iou) anchor_to_gt_max = iou.max(axis=1) + anchor_to_gt_argmax = iou.argmax(axis=1) + gt_to_anchor_argmax = iou.argmax(axis=0) gt_to_anchor_max = iou[gt_to_anchor_argmax, np.arange(iou.shape[1])] anchors_with_max_overlap = np.where(iou == gt_to_anchor_max)[0] @@ -42,59 +47,113 @@ def rpn_target_assign(iou, rpn_batch_size_per_im, rpn_positive_overlap, num_bg = rpn_batch_size_per_im - np.sum(tgt_lbl == 1) bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0] + tgt_lbl[bg_inds] = 0 if len(bg_inds) > num_bg: enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)] tgt_lbl[enable_inds] = 0 bg_inds = np.where(tgt_lbl == 0)[0] + tgt_lbl[bg_inds] = 0 loc_index = fg_inds score_index = np.hstack((fg_inds, bg_inds)) tgt_lbl = np.expand_dims(tgt_lbl, axis=1) - return loc_index, score_index, tgt_lbl + + gt_inds = anchor_to_gt_argmax[fg_inds] + + return loc_index, score_index, tgt_lbl, gt_inds + + +def get_anchor(n, c, h, w): + input_feat = np.random.random((n, c, h, w)).astype('float32') + anchors, _ = anchor_generator_in_python( + input_feat=input_feat, + anchor_sizes=[32., 64.], + aspect_ratios=[0.5, 1.0], + variances=[1.0, 1.0, 1.0, 1.0], + stride=[16.0, 16.0], + offset=0.5) + return anchors + + +def rpn_blob(anchor, gt_boxes, iou, lod, rpn_batch_size_per_im, + rpn_positive_overlap, rpn_negative_overlap, fg_fraction): + + loc_indexes = [] + score_indexes = [] + tmp_tgt_labels = [] + tgt_bboxes = [] + anchor_num = anchor.shape[0] + + batch_size = len(lod) - 1 + for i in range(batch_size): + b, e = lod[i], lod[i + 1] + iou_slice = iou[b:e, :] + bboxes_slice = gt_boxes[b:e, :] + + loc_idx, score_idx, tgt_lbl, gt_inds = rpn_target_assign( + iou_slice, rpn_batch_size_per_im, rpn_positive_overlap, + rpn_negative_overlap, fg_fraction) + + fg_bboxes = bboxes_slice[gt_inds] + fg_anchors = anchor[loc_idx] + box_deltas = _box_to_delta(fg_anchors, fg_bboxes, [1., 1., 1., 1.]) + + if i == 0: + loc_indexes = loc_idx + score_indexes = score_idx + tmp_tgt_labels = tgt_lbl + tgt_bboxes = box_deltas + else: + loc_indexes = np.concatenate( + [loc_indexes, loc_idx + i * anchor_num]) + score_indexes = np.concatenate( + [score_indexes, score_idx + i * anchor_num]) + tmp_tgt_labels = np.concatenate([tmp_tgt_labels, tgt_lbl]) + tgt_bboxes = np.vstack([tgt_bboxes, box_deltas]) + + tgt_labels = tmp_tgt_labels[score_indexes] + return loc_indexes, score_indexes, tgt_bboxes, tgt_labels class TestRpnTargetAssignOp(OpTest): def setUp(self): - iou = np.random.random((10, 8)).astype("float32") - self.op_type = "rpn_target_assign" - self.inputs = {'DistMat': iou} - self.attrs = { - 'rpn_batch_size_per_im': 256, - 'rpn_positive_overlap': 0.95, - 'rpn_negative_overlap': 0.3, - 'fg_fraction': 0.25, - 'fix_seed': True - } - loc_index, score_index, tgt_lbl = rpn_target_assign(iou, 256, 0.95, 0.3, - 0.25) - self.outputs = { - 'LocationIndex': loc_index, - 'ScoreIndex': score_index, - 'TargetLabel': tgt_lbl, - } + n, c, h, w = 2, 4, 14, 14 + anchor = get_anchor(n, c, h, w) + gt_num = 10 + anchor = anchor.reshape(-1, 4) + anchor_num = anchor.shape[0] - def test_check_output(self): - self.check_output() + im_shapes = [[64, 64], [64, 64]] + gt_box, lod = _generate_groundtruth(im_shapes, 3, 4) + bbox = np.vstack([v['boxes'] for v in gt_box]) + iou = _bbox_overlaps(bbox, anchor) + + anchor = anchor.astype('float32') + bbox = bbox.astype('float32') + iou = iou.astype('float32') + + loc_index, score_index, tgt_bbox, tgt_lbl = rpn_blob( + anchor, bbox, iou, [0, 4, 8], 25600, 0.95, 0.03, 0.25) -class TestRpnTargetAssignOp2(OpTest): - def setUp(self): - iou = np.random.random((10, 20)).astype("float32") self.op_type = "rpn_target_assign" - self.inputs = {'DistMat': iou} + self.inputs = { + 'Anchor': anchor, + 'GtBox': (bbox, [[4, 4]]), + 'DistMat': (iou, [[4, 4]]), + } self.attrs = { - 'rpn_batch_size_per_im': 128, - 'rpn_positive_overlap': 0.5, - 'rpn_negative_overlap': 0.5, - 'fg_fraction': 0.5, + 'rpn_batch_size_per_im': 25600, + 'rpn_positive_overlap': 0.95, + 'rpn_negative_overlap': 0.03, + 'fg_fraction': 0.25, 'fix_seed': True } - loc_index, score_index, tgt_lbl = rpn_target_assign(iou, 128, 0.5, 0.5, - 0.5) self.outputs = { - 'LocationIndex': loc_index, - 'ScoreIndex': score_index, - 'TargetLabel': tgt_lbl, + 'LocationIndex': loc_index.astype('int32'), + 'ScoreIndex': score_index.astype('int32'), + 'TargetBBox': tgt_bbox.astype('float32'), + 'TargetLabel': tgt_lbl.astype('int64'), } def test_check_output(self): -- GitLab