Unverified commit 9557cc21, authored by qingqing01, committed by GitHub

Refine and fix some code for faster-rcnn. (#13135)

* Fix bug in generate_proposals_op.
* Fix data type for RoIs.
* Refine and fix rpn_target_assign_op.
* Add the missing file bbox_util.h
* Rename BoxEncoder to BoxToDelta
Parent: d9c3123a

New file: paddle/fluid/operators/detection/bbox_util.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <cstring>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"

namespace paddle {
namespace operators {

/*
 * Transform that computes target bounding-box regression deltas
 * given proposal boxes and ground-truth boxes.
 */
template <typename T>
inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes,
                       const framework::Tensor& gt_boxes, const T* weights,
                       const bool normalized, framework::Tensor* box_delta) {
  auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
  auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
  auto trg = framework::EigenTensor<T, 2>::From(*box_delta);
  T ex_w, ex_h, ex_ctr_x, ex_ctr_y, gt_w, gt_h, gt_ctr_x, gt_ctr_y;
  for (int64_t i = 0; i < box_num; ++i) {
    ex_w = ex_boxes_et(i, 2) - ex_boxes_et(i, 0) + (normalized == false);
    ex_h = ex_boxes_et(i, 3) - ex_boxes_et(i, 1) + (normalized == false);
    ex_ctr_x = ex_boxes_et(i, 0) + 0.5 * ex_w;
    ex_ctr_y = ex_boxes_et(i, 1) + 0.5 * ex_h;

    gt_w = gt_boxes_et(i, 2) - gt_boxes_et(i, 0) + (normalized == false);
    gt_h = gt_boxes_et(i, 3) - gt_boxes_et(i, 1) + (normalized == false);
    gt_ctr_x = gt_boxes_et(i, 0) + 0.5 * gt_w;
    gt_ctr_y = gt_boxes_et(i, 1) + 0.5 * gt_h;

    trg(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w;
    trg(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h;
    trg(i, 2) = std::log(gt_w / ex_w);
    trg(i, 3) = std::log(gt_h / ex_h);

    if (weights) {
      trg(i, 0) = trg(i, 0) / weights[0];
      trg(i, 1) = trg(i, 1) / weights[1];
      trg(i, 2) = trg(i, 2) / weights[2];
      trg(i, 3) = trg(i, 3) / weights[3];
    }
  }
}

template <typename T>
void Gather(const T* in, const int in_stride, const int* index, const int num,
            T* out) {
  const int stride_bytes = in_stride * sizeof(T);
  for (int i = 0; i < num; ++i) {
    int id = index[i];
    memcpy(out + i * in_stride, in + id * in_stride, stride_bytes);
  }
}

}  // namespace operators
}  // namespace paddle
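Note: BoxToDelta above is the standard Faster R-CNN box parameterization: offsets of the box center relative to the anchor, scaled by anchor size, plus log ratios of width and height. A minimal standalone sketch of the same math (plain C++, illustrative names, no Paddle dependency):

#include <cmath>
#include <cstdio>

// Standalone sketch of the box-to-delta encoding above; `normalized == false`
// adds +1 to widths/heights, matching pixel-coordinate boxes.
struct Box { float x1, y1, x2, y2; };

void EncodeBoxToDelta(const Box& ex, const Box& gt, bool normalized,
                      float* delta) {
  float bias = normalized ? 0.f : 1.f;
  float ex_w = ex.x2 - ex.x1 + bias, ex_h = ex.y2 - ex.y1 + bias;
  float gt_w = gt.x2 - gt.x1 + bias, gt_h = gt.y2 - gt.y1 + bias;
  delta[0] = ((gt.x1 + 0.5f * gt_w) - (ex.x1 + 0.5f * ex_w)) / ex_w;  // dx
  delta[1] = ((gt.y1 + 0.5f * gt_h) - (ex.y1 + 0.5f * ex_h)) / ex_h;  // dy
  delta[2] = std::log(gt_w / ex_w);                                   // dw
  delta[3] = std::log(gt_h / ex_h);                                   // dh
}

int main() {
  Box ex{0.f, 0.f, 10.f, 10.f}, gt{2.f, 2.f, 12.f, 14.f};
  float d[4];
  EncodeBoxToDelta(ex, gt, /*normalized=*/false, d);
  std::printf("dx=%.3f dy=%.3f dw=%.3f dh=%.3f\n", d[0], d[1], d[2], d[3]);
  return 0;
}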
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/detection/bbox_util.h"
 #include "paddle/fluid/operators/gather.h"
 #include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -133,31 +134,6 @@ void BboxOverlaps(const Tensor& r_boxes, const Tensor& c_boxes,
   }
 }

-template <typename T>
-void BoxToDelta(int box_num, const Tensor& ex_boxes, const Tensor& gt_boxes,
-                const std::vector<float>& weights, Tensor* box_delta) {
-  auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
-  auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
-  auto box_delta_et = framework::EigenTensor<T, 2>::From(*box_delta);
-  T ex_w, ex_h, ex_ctr_x, ex_ctr_y, gt_w, gt_h, gt_ctr_x, gt_ctr_y;
-  for (int64_t i = 0; i < box_num; ++i) {
-    ex_w = ex_boxes_et(i, 2) - ex_boxes_et(i, 0) + 1;
-    ex_h = ex_boxes_et(i, 3) - ex_boxes_et(i, 1) + 1;
-    ex_ctr_x = ex_boxes_et(i, 0) + 0.5 * ex_w;
-    ex_ctr_y = ex_boxes_et(i, 1) + 0.5 * ex_h;
-    gt_w = gt_boxes_et(i, 2) - gt_boxes_et(i, 0) + 1;
-    gt_h = gt_boxes_et(i, 3) - gt_boxes_et(i, 1) + 1;
-    gt_ctr_x = gt_boxes_et(i, 0) + 0.5 * gt_w;
-    gt_ctr_y = gt_boxes_et(i, 1) + 0.5 * gt_h;
-    box_delta_et(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0];
-    box_delta_et(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1];
-    box_delta_et(i, 2) = log(gt_w / ex_w) / ex_w / weights[2];
-    box_delta_et(i, 3) = log(gt_h / ex_h) / ex_h / weights[3];
-  }
-}
-
 template <typename T>
 std::vector<std::vector<int>> SampleFgBgGt(
     const platform::CPUDeviceContext& context, Tensor* iou,
@@ -243,12 +219,11 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context,
                        Tensor* sampled_labels, Tensor* sampled_gts) {
   int fg_num = fg_inds.size();
   int bg_num = bg_inds.size();
-  int gt_num = fg_num + bg_num;
   Tensor fg_inds_t, bg_inds_t, gt_box_inds_t, gt_label_inds_t;
   int* fg_inds_data = fg_inds_t.mutable_data<int>({fg_num}, context.GetPlace());
   int* bg_inds_data = bg_inds_t.mutable_data<int>({bg_num}, context.GetPlace());
   int* gt_box_inds_data =
-      gt_box_inds_t.mutable_data<int>({gt_num}, context.GetPlace());
+      gt_box_inds_t.mutable_data<int>({fg_num}, context.GetPlace());
   int* gt_label_inds_data =
       gt_label_inds_t.mutable_data<int>({fg_num}, context.GetPlace());
   std::copy(fg_inds.begin(), fg_inds.end(), fg_inds_data);
@@ -303,18 +278,20 @@ std::vector<Tensor> SampleRoisForOneImage(
   // Gather boxes and labels
   Tensor sampled_boxes, sampled_labels, sampled_gts;
-  int boxes_num = fg_inds.size() + bg_inds.size();
+  int fg_num = fg_inds.size();
+  int bg_num = bg_inds.size();
+  int boxes_num = fg_num + bg_num;
   framework::DDim bbox_dim({boxes_num, kBoxDim});
   sampled_boxes.mutable_data<T>(bbox_dim, context.GetPlace());
   sampled_labels.mutable_data<int>({boxes_num}, context.GetPlace());
-  sampled_gts.mutable_data<T>(bbox_dim, context.GetPlace());
+  sampled_gts.mutable_data<T>({fg_num, kBoxDim}, context.GetPlace());
   GatherBoxesLabels<T>(context, boxes, *gt_boxes, *gt_classes, fg_inds, bg_inds,
                        gt_inds, &sampled_boxes, &sampled_labels, &sampled_gts);

   // Compute targets
   Tensor bbox_targets_single;
   bbox_targets_single.mutable_data<T>(bbox_dim, context.GetPlace());
-  BoxToDelta<T>(boxes_num, sampled_boxes, sampled_gts, bbox_reg_weights,
-                &bbox_targets_single);
+  BoxToDelta<T>(fg_num, sampled_boxes, sampled_gts, nullptr, false,
+                &bbox_targets_single);

   // Scale rois
@@ -427,7 +404,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
     auto rpn_rois_lod = rpn_rois->lod().back();
     auto gt_classes_lod = gt_classes->lod().back();
     auto gt_boxes_lod = gt_boxes->lod().back();
-    for (size_t i = 0; i < n; ++i) {
+    for (int i = 0; i < n; ++i) {
       Tensor rpn_rois_slice =
           rpn_rois->Slice(rpn_rois_lod[i], rpn_rois_lod[i + 1]);
       Tensor gt_classes_slice =
...
@@ -311,8 +311,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
     rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
                               context.GetPlace());
-    rpn_roi_probs->mutable_data<T>({scores->numel() / 4, 1},
-                                   context.GetPlace());
+    rpn_roi_probs->mutable_data<T>({scores->numel(), 1}, context.GetPlace());

     Tensor bbox_deltas_swap, scores_swap;
     bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
@@ -421,7 +420,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
     CPUGather<T>(ctx, proposals, keep, &bbox_sel);
     CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
     if (nms_thresh <= 0) {
-      return std::make_pair(bbox_sel, scores_sel);
+      return std::make_pair(bbox_sel, scores_filter);
     }
     Tensor keep_nms = NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);
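Note: the one-line fix above matters because bbox_sel and scores_filter are both gathered with the same keep index list, so their rows stay paired; returning the pre-filter scores_sel paired filtered boxes with unfiltered scores. A minimal sketch of that alignment invariant (plain C++, illustrative names, not Paddle's actual API):

#include <cstdio>
#include <vector>

// Rows of `boxes` and `scores` gathered by the same `keep` list stay
// index-aligned; any early return must hand back the filtered pair.
struct Filtered {
  std::vector<float> boxes;   // keep.size() * 4 values
  std::vector<float> scores;  // keep.size() values
};

Filtered GatherByKeep(const std::vector<float>& boxes,
                      const std::vector<float>& scores,
                      const std::vector<int>& keep) {
  Filtered out;
  for (int id : keep) {
    for (int k = 0; k < 4; ++k) out.boxes.push_back(boxes[id * 4 + k]);
    out.scores.push_back(scores[id]);  // same index keeps rows paired
  }
  return out;
}

int main() {
  std::vector<float> boxes = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5};
  std::vector<float> scores = {0.9f, 0.1f, 0.8f};
  std::vector<int> keep = {0, 2};  // e.g. survivors of a size filter
  Filtered f = GatherByKeep(boxes, scores, keep);
  std::printf("%zu boxes, %zu scores\n", f.boxes.size() / 4, f.scores.size());
  return 0;
}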
...
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <random>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/detection/bbox_util.h"
 #include "paddle/fluid/operators/math/math_function.h"

 namespace paddle {
@@ -46,156 +47,219 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
     auto in_dims = ctx->GetInputDim("DistMat");
     PADDLE_ENFORCE_EQ(in_dims.size(), 2,
                       "The rank of Input(DistMat) must be 2.");
+
+    ctx->SetOutputDim("LocationIndex", {-1});
+    ctx->SetOutputDim("ScoreIndex", {-1});
+    ctx->SetOutputDim("TargetLabel", {-1, 1});
+    ctx->SetOutputDim("TargetBBox", {-1, 4});
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(
+            ctx.Input<framework::LoDTensor>("DistMat")->type()),
+        platform::CPUPlace());
   }
 };

 template <typename T>
 class RpnTargetAssignKernel : public framework::OpKernel<T> {
  public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* anchor_t = context.Input<Tensor>("Anchor");  // (H*W*A) * 4
+    auto* gt_bbox_t = context.Input<Tensor>("GtBox");
+    auto* dist_t = context.Input<LoDTensor>("DistMat");
+
+    auto* loc_index_t = context.Output<Tensor>("LocationIndex");
+    auto* score_index_t = context.Output<Tensor>("ScoreIndex");
+    auto* tgt_bbox_t = context.Output<Tensor>("TargetBBox");
+    auto* tgt_lbl_t = context.Output<Tensor>("TargetLabel");
+
+    auto lod = dist_t->lod().back();
+    int64_t batch_num = static_cast<int64_t>(lod.size() - 1);
+    int64_t anchor_num = dist_t->dims()[1];
+    PADDLE_ENFORCE_EQ(anchor_num, anchor_t->dims()[0]);
+
+    int rpn_batch_size = context.Attr<int>("rpn_batch_size_per_im");
+    float pos_threshold = context.Attr<float>("rpn_positive_overlap");
+    float neg_threshold = context.Attr<float>("rpn_negative_overlap");
+    float fg_fraction = context.Attr<float>("fg_fraction");
+
+    int fg_num_per_batch = static_cast<int>(rpn_batch_size * fg_fraction);
+
+    int64_t max_num = batch_num * anchor_num;
+    auto place = context.GetPlace();
+
+    tgt_bbox_t->mutable_data<T>({max_num, 4}, place);
+    auto* loc_index = loc_index_t->mutable_data<int>({max_num}, place);
+    auto* score_index = score_index_t->mutable_data<int>({max_num}, place);
+
+    Tensor tmp_tgt_lbl;
+    auto* tmp_lbl_data = tmp_tgt_lbl.mutable_data<int64_t>({max_num}, place);
+    auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
+    math::SetConstant<platform::CPUDeviceContext, int64_t> iset;
+    iset(dev_ctx, &tmp_tgt_lbl, static_cast<int64_t>(-1));
+
+    std::random_device rnd;
+    std::minstd_rand engine;
+    int seed =
+        context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
+    engine.seed(seed);
+
+    int fg_num = 0;
+    int bg_num = 0;
+    for (int i = 0; i < batch_num; ++i) {
+      Tensor dist = dist_t->Slice(lod[i], lod[i + 1]);
+      Tensor gt_bbox = gt_bbox_t->Slice(lod[i], lod[i + 1]);
+      auto fg_bg_gt = SampleFgBgGt(dev_ctx, dist, pos_threshold, neg_threshold,
+                                   rpn_batch_size, fg_num_per_batch, engine,
+                                   tmp_lbl_data + i * anchor_num);
+
+      int cur_fg_num = fg_bg_gt[0].size();
+      int cur_bg_num = fg_bg_gt[1].size();
+      std::transform(fg_bg_gt[0].begin(), fg_bg_gt[0].end(), loc_index,
+                     [i, anchor_num](int d) { return d + i * anchor_num; });
+      memcpy(score_index, loc_index, cur_fg_num * sizeof(int));
+      std::transform(fg_bg_gt[1].begin(), fg_bg_gt[1].end(),
+                     score_index + cur_fg_num,
+                     [i, anchor_num](int d) { return d + i * anchor_num; });
+
+      // get target bbox deltas
+      if (cur_fg_num) {
+        Tensor fg_gt;
+        T* gt_data = fg_gt.mutable_data<T>({cur_fg_num, 4}, place);
+        Tensor tgt_bbox = tgt_bbox_t->Slice(fg_num, fg_num + cur_fg_num);
+        T* tgt_data = tgt_bbox.data<T>();
+        Gather<T>(anchor_t->data<T>(), 4,
+                  reinterpret_cast<int*>(&fg_bg_gt[0][0]), cur_fg_num,
+                  tgt_data);
+        Gather<T>(gt_bbox.data<T>(), 4, reinterpret_cast<int*>(&fg_bg_gt[2][0]),
+                  cur_fg_num, gt_data);
+        BoxToDelta<T>(cur_fg_num, tgt_bbox, fg_gt, nullptr, false, &tgt_bbox);
+      }
+
+      loc_index += cur_fg_num;
+      score_index += cur_fg_num + cur_bg_num;
+      fg_num += cur_fg_num;
+      bg_num += cur_bg_num;
+    }
+
+    int lbl_num = fg_num + bg_num;
+    PADDLE_ENFORCE_LE(fg_num, max_num);
+    PADDLE_ENFORCE_LE(lbl_num, max_num);
+
+    tgt_bbox_t->Resize({fg_num, 4});
+    loc_index_t->Resize({fg_num});
+    score_index_t->Resize({lbl_num});
+    auto* lbl_data = tgt_lbl_t->mutable_data<int64_t>({lbl_num, 1}, place);
+    Gather<int64_t>(tmp_lbl_data, 1, score_index_t->data<int>(), lbl_num,
+                    lbl_data);
+  }
+
+ private:
   void ScoreAssign(const T* dist_data, const Tensor& anchor_to_gt_max,
                    const int row, const int col, const float pos_threshold,
-                   const float neg_threshold, int64_t* target_label_data,
+                   const float neg_threshold, int64_t* target_label,
                    std::vector<int>* fg_inds, std::vector<int>* bg_inds) const {
-    int fg_offset = fg_inds->size();
-    int bg_offset = bg_inds->size();
+    float epsilon = 0.0001;
     for (int64_t i = 0; i < row; ++i) {
       const T* v = dist_data + i * col;
-      T max_dist = *std::max_element(v, v + col);
+      T max = *std::max_element(v, v + col);
       for (int64_t j = 0; j < col; ++j) {
-        T val = dist_data[i * col + j];
-        if (val == max_dist) target_label_data[j] = 1;
+        if (std::abs(max - v[j]) < epsilon) {
+          target_label[j] = 1;
+        }
       }
     }

-    // Pick the fg/bg and count the number
+    // Pick the fg/bg
+    const T* anchor_to_gt_max_data = anchor_to_gt_max.data<T>();
     for (int64_t j = 0; j < col; ++j) {
-      if (anchor_to_gt_max.data<T>()[j] > pos_threshold) {
-        target_label_data[j] = 1;
-      } else if (anchor_to_gt_max.data<T>()[j] < neg_threshold) {
-        target_label_data[j] = 0;
+      if (anchor_to_gt_max_data[j] >= pos_threshold) {
+        target_label[j] = 1;
+      } else if (anchor_to_gt_max_data[j] < neg_threshold) {
+        target_label[j] = 0;
       }
-      if (target_label_data[j] == 1) {
-        fg_inds->push_back(fg_offset + j);
-      } else if (target_label_data[j] == 0) {
-        bg_inds->push_back(bg_offset + j);
+      if (target_label[j] == 1) {
+        fg_inds->push_back(j);
+      } else if (target_label[j] == 0) {
+        bg_inds->push_back(j);
       }
     }
   }

-  void ReservoirSampling(const int num, const int offset,
-                         std::minstd_rand engine,
+  void ReservoirSampling(const int num, std::minstd_rand engine,
                          std::vector<int>* inds) const {
     std::uniform_real_distribution<float> uniform(0, 1);
-    const int64_t size = static_cast<int64_t>(inds->size() - offset);
-    if (size > num) {
-      for (int64_t i = num; i < size; ++i) {
+    size_t len = inds->size();
+    if (len > static_cast<size_t>(num)) {
+      for (size_t i = num; i < len; ++i) {
         int rng_ind = std::floor(uniform(engine) * i);
         if (rng_ind < num)
-          std::iter_swap(inds->begin() + rng_ind + offset,
-                         inds->begin() + i + offset);
+          std::iter_swap(inds->begin() + rng_ind, inds->begin() + i);
       }
+      inds->resize(num);
     }
   }
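Note: ReservoirSampling above keeps at most num of the collected indices with near-uniform probability, so foreground/background sampling does not favor low-index anchors. For reference, a standalone sketch of the textbook form (Algorithm R; plain C++, illustrative names — it draws from 0..i inclusive, where the kernel's variant draws from 0..i-1):

#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>

// Algorithm R: after the loop, `pool` holds k items drawn uniformly
// from the n inputs, in one pass and O(k) extra memory.
std::vector<int> ReservoirSample(const std::vector<int>& input, size_t k,
                                 std::minstd_rand* engine) {
  std::vector<int> pool(input.begin(),
                        input.begin() + std::min(k, input.size()));
  for (size_t i = k; i < input.size(); ++i) {
    // Element i replaces a random slot with probability k / (i + 1).
    std::uniform_int_distribution<size_t> pick(0, i);
    size_t j = pick(*engine);
    if (j < k) pool[j] = input[i];
  }
  return pool;
}

int main() {
  std::minstd_rand engine(42);
  std::vector<int> anchor_inds(100);
  for (int i = 0; i < 100; ++i) anchor_inds[i] = i;
  for (int v : ReservoirSample(anchor_inds, 8, &engine)) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}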
-  void RpnTargetAssign(const framework::ExecutionContext& ctx,
-                       const Tensor& dist, const float pos_threshold,
-                       const float neg_threshold, const int rpn_batch_size,
-                       const int fg_num, std::minstd_rand engine,
-                       std::vector<int>* fg_inds, std::vector<int>* bg_inds,
-                       int64_t* target_label_data) const {
+  // std::vector<std::vector<int>> RpnTargetAssign(
+  std::vector<std::vector<int>> SampleFgBgGt(
+      const platform::CPUDeviceContext& ctx, const Tensor& dist,
+      const float pos_threshold, const float neg_threshold,
+      const int rpn_batch_size, const int fg_num, std::minstd_rand engine,
+      int64_t* target_label) const {
     auto* dist_data = dist.data<T>();
-    int64_t row = dist.dims()[0];
-    int64_t col = dist.dims()[1];
-    int fg_offset = fg_inds->size();
-    int bg_offset = bg_inds->size();
+    int row = dist.dims()[0];
+    int col = dist.dims()[1];
+
+    std::vector<int> fg_inds;
+    std::vector<int> bg_inds;
+    std::vector<int> gt_inds;

     // Calculate the max IoU between anchors and gt boxes
-    Tensor anchor_to_gt_max;
-    anchor_to_gt_max.mutable_data<T>(
-        framework::make_ddim({static_cast<int64_t>(col), 1}),
-        platform::CPUPlace());
-    auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
-                       .eigen_device();
-    auto x = EigenMatrix<T>::From(dist);
-    auto x_col_max = EigenMatrix<T>::From(anchor_to_gt_max);
-    x_col_max.device(place) =
-        x.maximum(Eigen::DSizes<int, 1>(0))
-            .reshape(Eigen::DSizes<int, 2>(static_cast<int64_t>(col), 1));
+    // Map from anchor to gt box that has highest overlap
+    auto place = ctx.GetPlace();
+    Tensor anchor_to_gt_max, anchor_to_gt_argmax;
+    anchor_to_gt_max.mutable_data<T>({col}, place);
+    int* argmax = anchor_to_gt_argmax.mutable_data<int>({col}, place);
+
+    auto x = framework::EigenMatrix<T>::From(dist);
+    auto x_col_max = framework::EigenVector<T>::Flatten(anchor_to_gt_max);
+    auto x_col_argmax =
+        framework::EigenVector<int>::Flatten(anchor_to_gt_argmax);
+    x_col_max = x.maximum(Eigen::DSizes<int, 1>(0));
+    x_col_argmax = x.argmax(0).template cast<int>();
+
     // Follow the Faster RCNN's implementation
     ScoreAssign(dist_data, anchor_to_gt_max, row, col, pos_threshold,
-                neg_threshold, target_label_data, fg_inds, bg_inds);
+                neg_threshold, target_label, &fg_inds, &bg_inds);
+
     // Reservoir Sampling
-    ReservoirSampling(fg_num, fg_offset, engine, fg_inds);
-    int bg_num = rpn_batch_size - (fg_inds->size() - fg_offset);
-    ReservoirSampling(bg_num, bg_offset, engine, bg_inds);
-  }
-
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* dist = context.Input<LoDTensor>("DistMat");
-    auto* loc_index = context.Output<Tensor>("LocationIndex");
-    auto* score_index = context.Output<Tensor>("ScoreIndex");
-    auto* tgt_lbl = context.Output<Tensor>("TargetLabel");
-
-    auto col = dist->dims()[1];
-    int64_t n = dist->lod().size() == 0UL
-                    ? 1
-                    : static_cast<int64_t>(dist->lod().back().size() - 1);
-    if (dist->lod().size()) {
-      PADDLE_ENFORCE_EQ(dist->lod().size(), 1UL,
-                        "Only support 1 level of LoD.");
-    }
-    int rpn_batch_size = context.Attr<int>("rpn_batch_size_per_im");
-    float pos_threshold = context.Attr<float>("rpn_positive_overlap");
-    float neg_threshold = context.Attr<float>("rpn_negative_overlap");
-    float fg_fraction = context.Attr<float>("fg_fraction");
-
-    int fg_num = static_cast<int>(rpn_batch_size * fg_fraction);
-
-    int64_t* target_label_data =
-        tgt_lbl->mutable_data<int64_t>({n * col, 1}, context.GetPlace());
-
-    auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
-    math::SetConstant<platform::CPUDeviceContext, int64_t> iset;
-    iset(dev_ctx, tgt_lbl, static_cast<int>(-1));
-
-    std::vector<int> fg_inds;
-    std::vector<int> bg_inds;
-    std::random_device rnd;
-    std::minstd_rand engine;
-    int seed =
-        context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
-    engine.seed(seed);
-
-    if (n == 1) {
-      RpnTargetAssign(context, *dist, pos_threshold, neg_threshold,
-                      rpn_batch_size, fg_num, engine, &fg_inds, &bg_inds,
-                      target_label_data);
-    } else {
-      auto lod = dist->lod().back();
-      for (size_t i = 0; i < lod.size() - 1; ++i) {
-        Tensor one_ins = dist->Slice(lod[i], lod[i + 1]);
-        RpnTargetAssign(context, one_ins, pos_threshold, neg_threshold,
-                        rpn_batch_size, fg_num, engine, &fg_inds, &bg_inds,
-                        target_label_data + i * col);
-      }
-    }
-    int* loc_index_data = loc_index->mutable_data<int>(
-        {static_cast<int>(fg_inds.size())}, context.GetPlace());
-    int* score_index_data = score_index->mutable_data<int>(
-        {static_cast<int>(fg_inds.size() + bg_inds.size())},
-        context.GetPlace());
-    memcpy(loc_index_data, reinterpret_cast<int*>(&fg_inds[0]),
-           fg_inds.size() * sizeof(int));
-    memcpy(score_index_data, reinterpret_cast<int*>(&fg_inds[0]),
-           fg_inds.size() * sizeof(int));
-    memcpy(score_index_data + fg_inds.size(),
-           reinterpret_cast<int*>(&bg_inds[0]), bg_inds.size() * sizeof(int));
+    ReservoirSampling(fg_num, engine, &fg_inds);
+    int fg_num2 = static_cast<int>(fg_inds.size());
+    int bg_num = rpn_batch_size - fg_num2;
+    ReservoirSampling(bg_num, engine, &bg_inds);
+
+    gt_inds.reserve(fg_num2);
+    for (int i = 0; i < fg_num2; ++i) {
+      gt_inds.emplace_back(argmax[fg_inds[i]]);
+    }
+    std::vector<std::vector<int>> fg_bg_gt;
+    fg_bg_gt.emplace_back(fg_inds);
+    fg_bg_gt.emplace_back(bg_inds);
+    fg_bg_gt.emplace_back(gt_inds);
+
+    return fg_bg_gt;
   }
 };

 class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
+    AddInput("Anchor",
+             "(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4].");
+    AddInput("GtBox",
+             "(LoDTensor) input ground-truth bbox with shape [K, 4].");
     AddInput(
         "DistMat",
         "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
@@ -241,12 +305,15 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
         "ScoreIndex",
         "(Tensor), The indexes of foreground and background anchors in all "
         "RPN anchors(The rest anchors are ignored). The shape of the "
-        "ScoreIndex is [F + B], F and B depend on the value of input "
-        "tensor and attributes.");
-    AddOutput("TargetLabel",
-              "(Tensor<int64_t>), The target labels of each anchor with shape "
-              "[K * M, 1], "
-              "K and M is the same as they are in DistMat.");
+        "ScoreIndex is [F + B], F and B are the sampled foreground and "
+        "background numbers.");
+    AddOutput("TargetBBox",
+              "(Tensor<float>), The target bbox deltas with shape "
+              "[F, 4], F is the sampled foreground number.");
+    AddOutput(
+        "TargetLabel",
+        "(Tensor<int64_t>), The target labels of each anchor with shape "
+        "[F + B, 1], F and B are the sampled foreground and background "
+        "numbers.");
     AddComment(R"DOC(
 Given the IoU between the ground-truth bboxes and the anchors, this operator
 assigns classification and regression targets to each prediction.
...
@@ -31,7 +31,7 @@ static inline int NumBlocks(const int N) {
 template <typename T>
 __global__ void GPUROIPoolForward(
-    const int nthreads, const T* input_data, const int64_t* input_rois,
+    const int nthreads, const T* input_data, const T* input_rois,
     const float spatial_scale, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
     int* roi_batch_id_data, T* output_data, int64_t* argmax_data) {
@@ -43,7 +43,7 @@ __global__ void GPUROIPoolForward(
     int c = (i / pooled_width / pooled_height) % channels;
     int n = i / pooled_width / pooled_height / channels;

-    const int64_t* offset_input_rois = input_rois + n * kROISize;
+    const T* offset_input_rois = input_rois + n * kROISize;
     int roi_batch_ind = roi_batch_id_data[n];
     int roi_start_w = round(offset_input_rois[0] * spatial_scale);
     int roi_start_h = round(offset_input_rois[1] * spatial_scale);
@@ -93,7 +93,7 @@ __global__ void GPUROIPoolForward(
 template <typename T>
 __global__ void GPUROIPoolBackward(
-    const int nthreads, const int64_t* input_rois, const T* output_grad,
+    const int nthreads, const T* input_rois, const T* output_grad,
     const int64_t* argmax_data, const int num_rois, const float spatial_scale,
     const int channels, const int height, const int width,
     const int pooled_height, const int pooled_width, int* roi_batch_id_data,
@@ -174,8 +174,8 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
     GPUROIPoolForward<
         T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
-        output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
-        channels, height, width, pooled_height, pooled_width,
+        output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
+        height, width, pooled_height, pooled_width,
         roi_batch_id_list_gpu.data<int>(), out->mutable_data<T>(ctx.GetPlace()),
         argmax->mutable_data<int64_t>(ctx.GetPlace()));
   }
@@ -228,7 +228,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
     if (output_grad_size > 0) {
       GPUROIPoolBackward<
           T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
-          output_grad_size, rois->data<int64_t>(), out_grad->data<T>(),
+          output_grad_size, rois->data<T>(), out_grad->data<T>(),
          argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
          width, pooled_height, pooled_width,
          roi_batch_id_list_gpu.data<int>(),
...
@@ -72,7 +72,7 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
     T* output_data = out->mutable_data<T>(ctx.GetPlace());
     int64_t* argmax_data = argmax->mutable_data<int64_t>(ctx.GetPlace());

-    const int64_t* rois_data = rois->data<int64_t>();
+    const T* rois_data = rois->data<T>();
     for (int n = 0; n < rois_num; ++n) {
       int roi_batch_id = roi_batch_id_data[n];
       int roi_start_w = round(rois_data[0] * spatial_scale);
@@ -171,7 +171,7 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
       }
     }

-    const int64_t* rois_data = rois->data<int64_t>();
+    const T* rois_data = rois->data<T>();
     const T* out_grad_data = out_grad->data<T>();
     const int64_t* argmax_data = argmax->data<int64_t>();
     T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
...
@@ -145,26 +145,23 @@ def rpn_target_assign(loc,
     """
     helper = LayerHelper('rpn_target_assign', **locals())

-    # 1. Compute the regression target bboxes
-    target_bbox = box_coder(
-        prior_box=anchor_box,
-        prior_box_var=anchor_var,
-        target_box=gt_box,
-        code_type='encode_center_size',
-        box_normalized=False)
-
-    # 2. Compute overlaps between the prior boxes and the gt boxes overlaps
+    # Compute overlaps between the prior boxes and the gt boxes
     iou = iou_similarity(x=gt_box, y=anchor_box)

-    # 3. Assign target label to anchors
-    loc_index = helper.create_tmp_variable(dtype=anchor_box.dtype)
-    score_index = helper.create_tmp_variable(dtype=anchor_box.dtype)
-    target_label = helper.create_tmp_variable(dtype=anchor_box.dtype)
+    # Assign target label to anchors
+    loc_index = helper.create_tmp_variable(dtype='int32')
+    score_index = helper.create_tmp_variable(dtype='int32')
+    target_label = helper.create_tmp_variable(dtype='int64')
+    target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype)
     helper.append_op(
         type="rpn_target_assign",
-        inputs={'DistMat': iou},
+        inputs={'Anchor': anchor_box,
+                'GtBox': gt_box,
+                'DistMat': iou},
         outputs={
             'LocationIndex': loc_index,
             'ScoreIndex': score_index,
-            'TargetLabel': target_label
+            'TargetLabel': target_label,
+            'TargetBBox': target_bbox,
         },
         attrs={
             'rpn_batch_size_per_im': rpn_batch_size_per_im,
@@ -173,16 +170,16 @@ def rpn_target_assign(loc,
             'fg_fraction': fg_fraction
         })

-    # 4. Reshape and gather the target entry
-    scores = nn.reshape(x=scores, shape=(-1, 2))
-    loc = nn.reshape(x=loc, shape=(-1, 4))
-    target_label = nn.reshape(x=target_label, shape=(-1, 1))
-    target_bbox = nn.reshape(x=target_bbox, shape=(-1, 4))
+    loc_index.stop_gradient = True
+    score_index.stop_gradient = True
+    target_label.stop_gradient = True
+    target_bbox.stop_gradient = True

+    scores = nn.reshape(x=scores, shape=(-1, 1))
+    loc = nn.reshape(x=loc, shape=(-1, 4))
     predicted_scores = nn.gather(scores, score_index)
     predicted_location = nn.gather(loc, loc_index)
-    target_label = nn.gather(target_label, score_index)
-    target_bbox = nn.gather(target_bbox, loc_index)

     return predicted_scores, predicted_location, target_label, target_bbox
...
@@ -281,7 +281,7 @@ class TestRpnTargetAssign(unittest.TestCase):
             gt_box = layers.data(
                 name='gt_box', shape=[4], lod_level=1, dtype='float32')

-            predicted_scores, predicted_location, target_label, target_bbox = layers.rpn_target_assign(
+            pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign(
                 loc=loc,
                 scores=scores,
                 anchor_box=anchor_box,
@@ -292,15 +292,13 @@ class TestRpnTargetAssign(unittest.TestCase):
                 rpn_positive_overlap=0.7,
                 rpn_negative_overlap=0.3)

-            self.assertIsNotNone(predicted_scores)
-            self.assertIsNotNone(predicted_location)
-            self.assertIsNotNone(target_label)
-            self.assertIsNotNone(target_bbox)
-            assert predicted_scores.shape[1] == 2
-            assert predicted_location.shape[1] == 4
-            assert predicted_location.shape[1] == target_bbox.shape[1]
-
-        print(str(program))
+            self.assertIsNotNone(pred_scores)
+            self.assertIsNotNone(pred_loc)
+            self.assertIsNotNone(tgt_lbl)
+            self.assertIsNotNone(tgt_bbox)
+            assert pred_scores.shape[1] == 1
+            assert pred_loc.shape[1] == 4
+            assert pred_loc.shape[1] == tgt_bbox.shape[1]

 class TestGenerateProposals(unittest.TestCase):
...
@@ -177,8 +177,8 @@ def _box_to_delta(ex_boxes, gt_boxes, weights):
     dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0]
     dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1]
-    dw = (np.log(gt_w / ex_w)) / ex_w / weights[2]
-    dh = (np.log(gt_h / ex_h)) / ex_h / weights[3]
+    dw = (np.log(gt_w / ex_w)) / weights[2]
+    dh = (np.log(gt_h / ex_h)) / weights[3]

     targets = np.vstack([dx, dy, dw, dh]).transpose()
     return targets
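Note: the fix above removes the extra division of dw/dh by ex_w/ex_h, which the standard delta-to-box decoding does not expect; with it, decode(encode(ex, gt)) failed to recover gt. A minimal round-trip check of the corrected form (plain C++, illustrative names):

#include <cassert>
#include <cmath>
#include <cstdio>

// decode(encode(...)) must recover the ground-truth width; the pre-fix
// extra /ex_w on dw breaks this identity.
int main() {
  float ex_w = 10.f, gt_w = 14.f, weight = 1.f;
  float dw = std::log(gt_w / ex_w) / weight;       // fixed encoding
  float decoded_w = ex_w * std::exp(dw * weight);  // standard decoding
  assert(std::fabs(decoded_w - gt_w) < 1e-4f);
  std::printf("decoded_w = %.2f\n", decoded_w);
  return 0;
}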
...
@@ -61,7 +61,7 @@ class TestROIPoolOp(OpTest):
         for i in range(self.rois_num):
             roi = self.rois[i]
-            roi_batch_id = roi[0]
+            roi_batch_id = int(roi[0])
             roi_start_w = int(cpt.round(roi[1] * self.spatial_scale))
             roi_start_h = int(cpt.round(roi[2] * self.spatial_scale))
             roi_end_w = int(cpt.round(roi[3] * self.spatial_scale))
@@ -125,7 +125,7 @@ class TestROIPoolOp(OpTest):
                 roi = [bno, x1, y1, x2, y2]
                 rois.append(roi)
         self.rois_num = len(rois)
-        self.rois = np.array(rois).astype("int64")
+        self.rois = np.array(rois).astype("float32")

     def setUp(self):
         self.op_type = "roi_pool"
...
@@ -18,12 +18,17 @@ import unittest
 import numpy as np
 import paddle.fluid.core as core
 from op_test import OpTest
+from test_anchor_generator_op import anchor_generator_in_python
+from test_generate_proposal_labels import _generate_groundtruth
+from test_generate_proposal_labels import _bbox_overlaps, _box_to_delta


-def rpn_target_assign(iou, rpn_batch_size_per_im, rpn_positive_overlap,
-                      rpn_negative_overlap, fg_fraction):
-    iou = np.transpose(iou)
+def rpn_target_assign(gt_anchor_iou, rpn_batch_size_per_im,
+                      rpn_positive_overlap, rpn_negative_overlap, fg_fraction):
+    iou = np.transpose(gt_anchor_iou)
     anchor_to_gt_max = iou.max(axis=1)
+    anchor_to_gt_argmax = iou.argmax(axis=1)

     gt_to_anchor_argmax = iou.argmax(axis=0)
     gt_to_anchor_max = iou[gt_to_anchor_argmax, np.arange(iou.shape[1])]
     anchors_with_max_overlap = np.where(iou == gt_to_anchor_max)[0]
@@ -42,59 +47,113 @@ def rpn_target_assign(iou, rpn_batch_size_per_im, rpn_positive_overlap,
     num_bg = rpn_batch_size_per_im - np.sum(tgt_lbl == 1)
     bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0]
+    tgt_lbl[bg_inds] = 0
     if len(bg_inds) > num_bg:
         enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)]
         tgt_lbl[enable_inds] = 0
     bg_inds = np.where(tgt_lbl == 0)[0]
-    tgt_lbl[bg_inds] = 0

     loc_index = fg_inds
     score_index = np.hstack((fg_inds, bg_inds))
     tgt_lbl = np.expand_dims(tgt_lbl, axis=1)
-    return loc_index, score_index, tgt_lbl
+
+    gt_inds = anchor_to_gt_argmax[fg_inds]
+
+    return loc_index, score_index, tgt_lbl, gt_inds
+
+
+def get_anchor(n, c, h, w):
+    input_feat = np.random.random((n, c, h, w)).astype('float32')
+    anchors, _ = anchor_generator_in_python(
+        input_feat=input_feat,
+        anchor_sizes=[32., 64.],
+        aspect_ratios=[0.5, 1.0],
+        variances=[1.0, 1.0, 1.0, 1.0],
+        stride=[16.0, 16.0],
+        offset=0.5)
+    return anchors
+
+
+def rpn_blob(anchor, gt_boxes, iou, lod, rpn_batch_size_per_im,
+             rpn_positive_overlap, rpn_negative_overlap, fg_fraction):
+
+    loc_indexes = []
+    score_indexes = []
+    tmp_tgt_labels = []
+    tgt_bboxes = []
+    anchor_num = anchor.shape[0]
+
+    batch_size = len(lod) - 1
+    for i in range(batch_size):
+        b, e = lod[i], lod[i + 1]
+        iou_slice = iou[b:e, :]
+        bboxes_slice = gt_boxes[b:e, :]
+
+        loc_idx, score_idx, tgt_lbl, gt_inds = rpn_target_assign(
+            iou_slice, rpn_batch_size_per_im, rpn_positive_overlap,
+            rpn_negative_overlap, fg_fraction)
+
+        fg_bboxes = bboxes_slice[gt_inds]
+        fg_anchors = anchor[loc_idx]
+        box_deltas = _box_to_delta(fg_anchors, fg_bboxes, [1., 1., 1., 1.])
+
+        if i == 0:
+            loc_indexes = loc_idx
+            score_indexes = score_idx
+            tmp_tgt_labels = tgt_lbl
+            tgt_bboxes = box_deltas
+        else:
+            loc_indexes = np.concatenate(
+                [loc_indexes, loc_idx + i * anchor_num])
+            score_indexes = np.concatenate(
+                [score_indexes, score_idx + i * anchor_num])
+            tmp_tgt_labels = np.concatenate([tmp_tgt_labels, tgt_lbl])
+            tgt_bboxes = np.vstack([tgt_bboxes, box_deltas])

+    tgt_labels = tmp_tgt_labels[score_indexes]
+    return loc_indexes, score_indexes, tgt_bboxes, tgt_labels


 class TestRpnTargetAssignOp(OpTest):
     def setUp(self):
-        iou = np.random.random((10, 8)).astype("float32")
-        self.op_type = "rpn_target_assign"
-        self.inputs = {'DistMat': iou}
-        self.attrs = {
-            'rpn_batch_size_per_im': 256,
-            'rpn_positive_overlap': 0.95,
-            'rpn_negative_overlap': 0.3,
-            'fg_fraction': 0.25,
-            'fix_seed': True
-        }
-        loc_index, score_index, tgt_lbl = rpn_target_assign(iou, 256, 0.95, 0.3,
-                                                            0.25)
-        self.outputs = {
-            'LocationIndex': loc_index,
-            'ScoreIndex': score_index,
-            'TargetLabel': tgt_lbl,
-        }
+        n, c, h, w = 2, 4, 14, 14
+        anchor = get_anchor(n, c, h, w)
+        gt_num = 10
+        anchor = anchor.reshape(-1, 4)
+        anchor_num = anchor.shape[0]

-    def test_check_output(self):
-        self.check_output()
+        im_shapes = [[64, 64], [64, 64]]
+        gt_box, lod = _generate_groundtruth(im_shapes, 3, 4)
+        bbox = np.vstack([v['boxes'] for v in gt_box])

+        iou = _bbox_overlaps(bbox, anchor)

-class TestRpnTargetAssignOp2(OpTest):
-    def setUp(self):
-        iou = np.random.random((10, 20)).astype("float32")
+        anchor = anchor.astype('float32')
+        bbox = bbox.astype('float32')
+        iou = iou.astype('float32')
+
+        loc_index, score_index, tgt_bbox, tgt_lbl = rpn_blob(
+            anchor, bbox, iou, [0, 4, 8], 25600, 0.95, 0.03, 0.25)
+
         self.op_type = "rpn_target_assign"
-        self.inputs = {'DistMat': iou}
+        self.inputs = {
+            'Anchor': anchor,
+            'GtBox': (bbox, [[4, 4]]),
+            'DistMat': (iou, [[4, 4]]),
+        }
         self.attrs = {
-            'rpn_batch_size_per_im': 128,
-            'rpn_positive_overlap': 0.5,
-            'rpn_negative_overlap': 0.5,
-            'fg_fraction': 0.5,
+            'rpn_batch_size_per_im': 25600,
+            'rpn_positive_overlap': 0.95,
+            'rpn_negative_overlap': 0.03,
+            'fg_fraction': 0.25,
             'fix_seed': True
         }
-        loc_index, score_index, tgt_lbl = rpn_target_assign(iou, 128, 0.5, 0.5,
-                                                            0.5)
         self.outputs = {
-            'LocationIndex': loc_index,
-            'ScoreIndex': score_index,
-            'TargetLabel': tgt_lbl,
+            'LocationIndex': loc_index.astype('int32'),
+            'ScoreIndex': score_index.astype('int32'),
+            'TargetBBox': tgt_bbox.astype('float32'),
+            'TargetLabel': tgt_lbl.astype('int64'),
         }

     def test_check_output(self):
...