Commit 593ad763 authored by Yu Yang

refactor(op): polish generate_proposals_op

Polish styles in generate_proposals_op.

1. Inline lambda functions rather than using std::function, which saves a
   variable.
2. Add `static inline` to template functions in .cc files.
   * Make them static to avoid generating exported symbols.
   * Make them inline to give the compiler a hint to inline them where
     possible.
   * Note: if the functions were not static, they could not be inlined,
     since their symbols would have to be exported.
3. Add `static` to global functions in .cc files.
   * Make them static to avoid generating exported symbols.
4. Use Vector<uint64_t> instead of manually managing storage across devices.
5. Prefer platform::ForRange, so `ForRange` can later be optimized by
   changing only `for_range.h` if needed (see the sketch below).
6. Do not change the shape of inputs.
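
As a hedged illustration of point 5, here is a minimal sketch of the ForRange pattern (the names follow paddle/fluid/platform/for_range.h; ScaleFunctor is a hypothetical example, not part of this commit):

```cpp
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/hostdevice.h"

template <typename T>
struct ScaleFunctor {  // hypothetical functor, for illustration only
  const T *in_;
  T *out_;
  T factor_;
  // HOSTDEVICE lets the same body run in both CPU and CUDA builds.
  HOSTDEVICE void operator()(size_t i) const { out_[i] = in_[i] * factor_; }
};

// Usage: one call site per device context; tuning ForRange itself
// (block size, vectorization, ...) then benefits every caller.
//   platform::ForRange<DeviceContext> for_range(dev_ctx, n);
//   for_range(ScaleFunctor<T>{in, out, factor});
```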

test=develop
Parent 7a5f3f75
......@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -25,21 +27,17 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
struct AppendProposalsFunctor {
LoDTensor *out_;
int64_t offset_;
Tensor *to_add_;
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
AppendProposalsFunctor(LoDTensor *out, int64_t offset, Tensor *to_add)
: out_(out), offset_(offset), to_add_(to_add) {}
template <typename T>
void apply() const {
auto *out_data = out_->data<T>();
auto *to_add_data = to_add_->data<T>();
memcpy(out_data + offset_, to_add_data, to_add_->numel() * sizeof(T));
}
};
static void AppendProposals(Tensor *dst, int64_t offset, const Tensor &src) {
auto *out_data = dst->data<void>();
auto *to_add_data = src.data<void>();
size_t size_of_t = framework::SizeOfType(src.type());
offset *= size_of_t;
std::memcpy(
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(out_data) + offset),
to_add_data, src.numel() * size_of_t);
}
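
The new AppendProposals replaces the per-dtype AppendProposalsFunctor (dispatched through VisitDataType) with a single type-erased copy. A self-contained sketch of the same idea in plain C++, without Paddle types:

```cpp
#include <cstdint>
#include <cstring>

// Offset and count are in elements; scaling by the element size lets one
// memcpy serve every dtype, so no template dispatch is needed.
static void AppendBytes(void *dst, int64_t elem_offset, const void *src,
                        int64_t elem_count, size_t elem_size) {
  std::memcpy(reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(dst) +
                                       elem_offset * elem_size),
              src, elem_count * elem_size);
}
```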
class GenerateProposalsOp : public framework::OperatorWithKernel {
public:
......@@ -75,8 +73,9 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
};
template <class T>
void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
Tensor *bbox_deltas, Tensor *variances, Tensor *proposals) {
static inline void BoxCoder(const platform::DeviceContext &ctx,
Tensor *all_anchors, Tensor *bbox_deltas,
Tensor *variances, Tensor *proposals) {
T *proposals_data = proposals->mutable_data<T>(ctx.GetPlace());
int64_t row = all_anchors->dims()[0];
......@@ -108,11 +107,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
anchor_center_y;
bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) *
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
kBBoxClipDefault)) *
anchor_height;
} else {
bbox_center_x =
......@@ -120,10 +119,10 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) *
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
kBBoxClipDefault)) *
anchor_height;
}
......@@ -136,30 +135,32 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
}
template <class T>
void ClipTiledBoxes(const platform::DeviceContext &ctx, const Tensor &im_info,
Tensor *boxes) {
static inline void ClipTiledBoxes(const platform::DeviceContext &ctx,
const Tensor &im_info, Tensor *boxes) {
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
const T *im_info_data = im_info.data<T>();
T zero(0);
for (int64_t i = 0; i < boxes->numel(); ++i) {
if (i % 4 == 0) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
} else if (i % 4 == 1) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
} else if (i % 4 == 2) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
} else {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
}
}
}
template <class T>
void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
float min_size, const Tensor &im_info, Tensor *keep) {
static inline void FilterBoxes(const platform::DeviceContext &ctx,
Tensor *boxes, float min_size,
const Tensor &im_info, Tensor *keep) {
const T *im_info_data = im_info.data<T>();
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
T im_scale = im_info_data[2];
......@@ -185,24 +186,24 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
keep->Resize({keep_len});
}
bool SortScorePairDescend(const std::pair<float, int> &pair1,
const std::pair<float, int> &pair2) {
return pair1.first > pair2.first;
}
template <class T>
void GetMaxScoreIndex(const std::vector<T> &scores,
std::vector<std::pair<T, int>> *sorted_indices) {
static inline std::vector<std::pair<T, int>> GetSortedScoreIndex(
const std::vector<T> &scores) {
std::vector<std::pair<T, int>> sorted_indices;
sorted_indices.reserve(scores.size());
for (size_t i = 0; i < scores.size(); ++i) {
sorted_indices->push_back(std::make_pair(scores[i], i));
sorted_indices.emplace_back(scores[i], i);
}
// Sort the score pairs by score in ascending order, so the highest score
// sits at the back of the vector
std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
SortScorePairDescend);
std::stable_sort(sorted_indices.begin(), sorted_indices.end(),
[](const std::pair<T, int> &a, const std::pair<T, int> &b) {
return a.first < b.first;
});
return sorted_indices;
}
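
Note the sort is ascending on purpose: NMS consumes the highest-scoring candidate first, and with an ascending vector that candidate sits at the back, so both the read (back()) and the removal (erase at the end) are O(1), whereas erasing the front of a descending vector shifts every remaining element. A self-contained sketch of the consumption pattern:

```cpp
#include <utility>
#include <vector>

static void ConsumeBestFirst(std::vector<std::pair<float, int>> sorted) {
  // `sorted` is ascending by score, as produced by GetSortedScoreIndex.
  while (!sorted.empty()) {
    int idx = sorted.back().second;  // highest remaining score
    (void)idx;                       // ... overlap tests would go here ...
    sorted.erase(sorted.end() - 1);  // O(1) pop from the back
  }
}
```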
template <class T>
T BBoxArea(const T *box, const bool normalized) {
static inline T BBoxArea(const T *box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
......@@ -220,7 +221,7 @@ T BBoxArea(const T *box, const bool normalized) {
}
template <class T>
T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
static inline T JaccardOverlap(const T *box1, const T *box2, bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
......@@ -229,8 +230,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = std::max(0.0f, inter_xmax - inter_xmin + 1);
const T inter_h = std::max(0.0f, inter_ymax - inter_ymin + 1);
const T inter_w = std::max(T(0), inter_xmax - inter_xmin + 1);
const T inter_h = std::max(T(0), inter_ymax - inter_ymin + 1);
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
......@@ -238,9 +239,21 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
}
}
template <typename T>
static inline Tensor VectorToTensor(const std::vector<T> &selected_indices,
int selected_num) {
Tensor keep_nms;
keep_nms.Resize({selected_num});
auto *keep_data = keep_nms.mutable_data<T>(platform::CPUPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
}
template <class T>
Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
const T nms_threshold, const float eta) {
static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
Tensor *scores, T nms_threshold, float eta) {
PADDLE_ENFORCE_NOT_NULL(bbox);
int64_t num_boxes = bbox->dims()[0];
// 4: [xmin ymin xmax ymax]
......@@ -248,20 +261,18 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
std::vector<T> scores_data(num_boxes);
std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
GetMaxScoreIndex<T>(scores_data, &sorted_indices);
std::vector<std::pair<T, int>> sorted_indices =
GetSortedScoreIndex<T>(scores_data);
std::vector<int> selected_indices;
int selected_num = 0;
T adaptive_threshold = nms_threshold;
const T *bbox_data = bbox->data<T>();
bool flag;
while (sorted_indices.size() != 0) {
int idx = sorted_indices.front().second;
flag = true;
for (size_t k = 0; k < selected_indices.size(); ++k) {
int idx = sorted_indices.back().second;
bool flag = true;
for (int kept_idx : selected_indices) {
if (flag) {
const int kept_idx = selected_indices[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, false);
flag = (overlap <= adaptive_threshold);
......@@ -271,32 +282,29 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
}
if (flag) {
selected_indices.push_back(idx);
selected_num++;
++selected_num;
}
sorted_indices.erase(sorted_indices.begin());
sorted_indices.erase(sorted_indices.end() - 1);
if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
Tensor keep_nms;
keep_nms.Resize({selected_num});
int *keep_data = keep_nms.mutable_data<int>(ctx.GetPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
return VectorToTensor(selected_indices, selected_num);
}
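
For readers tracing the adaptive threshold: with eta < 1 the IoU threshold decays geometrically after each kept box, but only while it remains above 0.5. A tiny runnable example of the schedule (values are illustrative):

```cpp
#include <cstdio>

int main() {
  float adaptive_threshold = 0.7f;
  const float eta = 0.9f;
  for (int kept = 0; kept < 6; ++kept) {
    std::printf("threshold seen by box %d: %g\n", kept, adaptive_threshold);
    if (eta < 1 && adaptive_threshold > 0.5) adaptive_threshold *= eta;
  }
  // Prints 0.7, 0.63, 0.567, 0.5103, 0.45927, 0.45927: the decay stops
  // once the threshold falls to 0.5 or below.
  return 0;
}
```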
template <typename DeviceContext, typename T>
template <typename T>
class GenerateProposalsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo");
auto *anchors = context.Input<Tensor>("Anchors");
auto *variances = context.Input<Tensor>("Variances");
auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
"Cannot find input Anchors(%s) in scope",
context.Inputs("Anchors")[0]);
auto variances = detail::Ref(context.Input<Tensor>("Variances"),
"Cannot find input Variances(%s) in scope",
context.Inputs("Variances")[0]);
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
......@@ -307,15 +315,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta");
auto &dev_ctx = context.template device_context<DeviceContext>();
auto &dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
auto scores_dim = scores->dims();
auto &scores_dim = scores->dims();
int64_t num = scores_dim[0];
int64_t c_score = scores_dim[1];
int64_t h_score = scores_dim[2];
int64_t w_score = scores_dim[3];
auto bbox_dim = bbox_deltas->dims();
auto &bbox_dim = bbox_deltas->dims();
int64_t c_bbox = bbox_dim[1];
int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3];
......@@ -330,17 +339,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace());
math::Transpose<DeviceContext, T, 4> trans;
math::Transpose<platform::CPUDeviceContext, T, 4> trans;
std::vector<int> axis = {0, 2, 3, 1};
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis);
framework::LoD lod;
std::vector<size_t> lod0(1, 0);
Tensor *anchor = const_cast<framework::Tensor *>(anchors);
anchor->Resize({anchors->numel() / 4, 4});
Tensor *var = const_cast<framework::Tensor *>(variances);
var->Resize({var->numel() / 4, 4});
lod.resize(1);
auto &lod0 = lod[0];
lod0.push_back(0);
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
int64_t num_proposals = 0;
for (int64_t i = 0; i < num; ++i) {
......@@ -352,24 +361,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> tensor_pair =
ProposalForOneImage(dev_ctx, im_info_slice, *anchor, *var,
ProposalForOneImage(dev_ctx, im_info_slice, anchors, variances,
bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta);
Tensor proposals = tensor_pair.first;
Tensor scores = tensor_pair.second;
framework::VisitDataType(
framework::ToDataType(rpn_rois->type()),
AppendProposalsFunctor(rpn_rois, 4 * num_proposals, &proposals));
framework::VisitDataType(
framework::ToDataType(rpn_roi_probs->type()),
AppendProposalsFunctor(rpn_roi_probs, num_proposals, &scores));
Tensor &proposals = tensor_pair.first;
Tensor &scores = tensor_pair.second;
AppendProposals(rpn_rois, 4 * num_proposals, proposals);
AppendProposals(rpn_roi_probs, num_proposals, scores);
num_proposals += proposals.dims()[0];
lod0.emplace_back(num_proposals);
lod0.push_back(num_proposals);
}
lod.emplace_back(lod0);
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
......@@ -377,7 +379,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
}
std::pair<Tensor, Tensor> ProposalForOneImage(
const DeviceContext &ctx, const Tensor &im_info_slice,
const platform::CPUDeviceContext &ctx, const Tensor &im_info_slice,
const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas_slice, // [M, 4]
const Tensor &scores_slice, // [N, 1]
......@@ -392,10 +394,9 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
for (int i = 0; i < scores_slice.numel(); ++i) {
index[i] = i;
}
std::function<bool(const int64_t &, const int64_t &)> compare =
[scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
};
auto compare = [scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
};
if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) {
std::sort(index, index + scores_slice.numel(), compare);
......@@ -469,12 +470,12 @@ class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
Generate Proposals OP
This operator proposes rois according to each box with their probability to be a foreground object and
the box can be calculated by anchors. Bbox_deltais and scores are the output of RPN. Final proposals
the box can be calculated by anchors. Bbox_deltas and scores are the output of the RPN. Final proposals
could be used to train detection net.
Scores is the probability for each box to be an object, in the format (N, A, H, W), where N is the batch size, A is the number
of anchors, and H and W are the height and width of the feature map.
BboxDeltas is the differece between predicted box locatoin and anchor location. In format of (N, 4*A, H, W)
BboxDeltas is the difference between the predicted box location and the anchor location, in the format (N, 4*A, H, W).
For generating proposals, this operator transposes and resizes scores and bbox_deltas to the shapes (H*W*A, 1) and (H*W*A, 4), and
calculates box locations as proposal candidates. It then clips boxes to the image and removes predicted boxes with small area.
......@@ -490,6 +491,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(generate_proposals, ops::GenerateProposalsOp,
ops::GenerateProposalsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
generate_proposals,
ops::GenerateProposalsKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(generate_proposals, ops::GenerateProposalsKernel<float>,
ops::GenerateProposalsKernel<double>);
......@@ -16,10 +16,13 @@ limitations under the License. */
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
......@@ -36,36 +39,38 @@ namespace {
int const kThreadsPerBlock = sizeof(uint64_t) * 8;
template <typename T>
__global__ void RangeInitKernel(const T start, const T delta, const int size,
T *out) {
CUDA_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
}
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
struct RangeInitFunctor {
int start_;
int delta_;
int *out_;
__device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; }
};
template <typename T>
void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value,
Tensor *value_out, Tensor *index_out) {
int num = value.numel();
static void SortDescending(const platform::CUDADeviceContext &ctx,
const Tensor &value, Tensor *value_out,
Tensor *index_out) {
int num = static_cast<int>(value.numel());
Tensor index_in_t;
int *idx_in = index_in_t.mutable_data<int>({num}, ctx.GetPlace());
int block = 512;
auto stream = ctx.stream();
RangeInitKernel<<<DIVUP(num, block), block, 0, stream>>>(0, 1, num, idx_in);
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, num);
for_range(RangeInitFunctor{0, 1, idx_in});
int *idx_out = index_out->mutable_data<int>({num}, ctx.GetPlace());
const T *keys_in = value.data<T>();
T *keys_out = value_out->mutable_data<T>({num}, ctx.GetPlace());
// Determine temporary device storage requirements
void *d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<T, int>(
d_temp_storage, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out,
num);
nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
// Allocate temporary storage
auto place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
void *d_temp_storage = memory::Alloc(place, temp_storage_bytes);
// Run sorting operation
cub::DeviceRadixSort::SortPairsDescending<T, int>(
......@@ -76,22 +81,27 @@ void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value,
}
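
SortDescending follows the standard two-phase CUB protocol: a first call with a null workspace only reports the scratch size, the caller allocates exactly that much device memory, and a second identical call performs the sort. A hedged sketch of the pattern in isolation (keys_in/keys_out/idx_in/idx_out are assumed to be valid device buffers of length num):

```cuda
size_t temp_storage_bytes = 0;
// Phase 1: null workspace => CUB only computes temp_storage_bytes.
cub::DeviceRadixSort::SortPairsDescending<float, int>(
    nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
// Phase 2: run the real sort with exactly-sized scratch memory.
void *d_temp_storage = memory::Alloc(place, temp_storage_bytes);
cub::DeviceRadixSort::SortPairsDescending<float, int>(
    d_temp_storage, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out,
    num);
```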
template <typename T>
__device__ __forceinline__ T Min(T x, T y) {
return x < y ? x : y;
}
template <typename T>
__device__ __forceinline__ T Max(T x, T y) {
return x > y ? x : y;
}
template <typename T>
__global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
const T *var, const int *index,
const T *im_info, const int num,
T *proposals) {
T kBBoxClipDefault = log(1000.0 / 16.0);
CUDA_1D_KERNEL_LOOP(i, num) {
struct BoxDecodeAndClipFunctor {
const T *anchor;
const T *deltas;
const T *var;
const int *index;
const T *im_info;
T *proposals;
BoxDecodeAndClipFunctor(const T *anchor, const T *deltas, const T *var,
const int *index, const T *im_info, T *proposals)
: anchor(anchor),
deltas(deltas),
var(var),
index(index),
im_info(im_info),
proposals(proposals) {}
T bbox_clip_default{static_cast<T>(kBBoxClipDefault)};
__device__ void operator()(size_t i) {
int k = index[i] * 4;
T axmin = anchor[k];
T aymin = anchor[k + 1];
......@@ -108,17 +118,17 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
T dxmax = deltas[k + 2];
T dymax = deltas[k + 3];
T d_cx = 0., d_cy = 0., d_w = 0., d_h = 0.;
T d_cx, d_cy, d_w, d_h;
if (var) {
d_cx = cx + dxmin * w * var[k];
d_cy = cy + dymin * h * var[k + 1];
d_w = exp(Min<T>(dxmax * var[k + 2], kBBoxClipDefault)) * w;
d_h = exp(Min<T>(dymax * var[k + 3], kBBoxClipDefault)) * h;
d_w = exp(Min(dxmax * var[k + 2], bbox_clip_default)) * w;
d_h = exp(Min(dymax * var[k + 3], bbox_clip_default)) * h;
} else {
d_cx = cx + dxmin * w;
d_cy = cy + dymin * h;
d_w = exp(Min<T>(dxmax, kBBoxClipDefault)) * w;
d_h = exp(Min<T>(dymax, kBBoxClipDefault)) * h;
d_w = exp(Min(dxmax, bbox_clip_default)) * w;
d_h = exp(Min(dymax, bbox_clip_default)) * h;
}
T oxmin = d_cx - d_w * 0.5;
......@@ -126,17 +136,21 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
T oxmax = d_cx + d_w * 0.5 - 1.;
T oymax = d_cy + d_h * 0.5 - 1.;
proposals[i * 4] = Max<T>(Min<T>(oxmin, im_info[1] - 1.), 0.);
proposals[i * 4 + 1] = Max<T>(Min<T>(oymin, im_info[0] - 1.), 0.);
proposals[i * 4 + 2] = Max<T>(Min<T>(oxmax, im_info[1] - 1.), 0.);
proposals[i * 4 + 3] = Max<T>(Min<T>(oymax, im_info[0] - 1.), 0.);
proposals[i * 4] = Max(Min(oxmin, im_info[1] - 1.), 0.);
proposals[i * 4 + 1] = Max(Min(oymin, im_info[0] - 1.), 0.);
proposals[i * 4 + 2] = Max(Min(oxmax, im_info[1] - 1.), 0.);
proposals[i * 4 + 3] = Max(Min(oymax, im_info[0] - 1.), 0.);
}
}
__device__ __forceinline__ T Min(T a, T b) const { return a > b ? b : a; }
__device__ __forceinline__ T Max(T a, T b) const { return a > b ? a : b; }
};
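
One subtlety in BoxDecodeAndClipFunctor: kBBoxClipDefault is computed on the host with std::log, which the old kernel recomputed with log() inside device code; the functor instead copies it into the by-value member bbox_clip_default at construction, so the constant travels to the GPU inside the functor object that ForRange passes to its kernel. A minimal sketch of that capture pattern (ClampFunctor is hypothetical):

```cuda
// Hypothetical functor: a host-computed constant captured by value so it
// is readable from device code without a __constant__ symbol or recompute.
struct ClampFunctor {
  float limit{static_cast<float>(kBBoxClipDefault)};  // initialized on host
  float *data;
  __device__ void operator()(size_t i) {
    data[i] = data[i] < limit ? data[i] : limit;  // read on the device
  }
};
```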
template <typename T, int BlockSize>
__global__ void FilterBBoxes(const T *bboxes, const T *im_info,
const T min_size, const int num, int *keep_num,
int *keep) {
static __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
const T min_size, const int num,
int *keep_num, int *keep) {
T im_h = im_info[0];
T im_w = im_info[1];
T im_scale = im_info[2];
......@@ -181,7 +195,7 @@ __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
}
}
__device__ inline float IoU(const float *a, const float *b) {
static __device__ inline float IoU(const float *a, const float *b) {
float left = max(a[0], b[0]), right = min(a[2], b[2]);
float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
......@@ -191,8 +205,9 @@ __device__ inline float IoU(const float *a, const float *b) {
return inter_s / (s_a + s_b - inter_s);
}
__global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh,
const float *dev_boxes, uint64_t *dev_mask) {
static __global__ void NMSKernel(const int n_boxes,
const float nms_overlap_thresh,
const float *dev_boxes, uint64_t *dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
......@@ -234,9 +249,9 @@ __global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh,
}
template <typename T>
void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const Tensor &sorted_indices, const T nms_threshold,
Tensor *keep_out) {
static void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const Tensor &sorted_indices, const T nms_threshold,
Tensor *keep_out) {
int boxes_num = proposals.dims()[0];
PADDLE_ENFORCE_EQ(boxes_num, sorted_indices.dims()[0]);
......@@ -247,13 +262,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const T *boxes = proposals.data<T>();
auto place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
int size_bytes = boxes_num * col_blocks * sizeof(uint64_t);
uint64_t *d_mask =
reinterpret_cast<uint64_t *>(memory::Alloc(place, size_bytes));
NMSKernel<<<blocks, threads>>>(boxes_num, nms_threshold, boxes, d_mask);
uint64_t *h_mask = reinterpret_cast<uint64_t *>(
memory::Alloc(platform::CPUPlace(), size_bytes));
memory::Copy(platform::CPUPlace(), h_mask, place, d_mask, size_bytes, 0);
framework::Vector<uint64_t> mask(boxes_num * col_blocks);
NMSKernel<<<blocks, threads>>>(
boxes_num, nms_threshold, boxes,
mask.CUDAMutableData(boost::get<platform::CUDAPlace>(ctx.GetPlace())));
std::vector<uint64_t> remv(col_blocks);
memset(&remv[0], 0, sizeof(uint64_t) * col_blocks);
......@@ -267,7 +279,7 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
if (!(remv[nblock] & (1ULL << inblock))) {
++num_to_keep;
keep_vec.push_back(i);
uint64_t *p = &h_mask[0] + i * col_blocks;
uint64_t *p = &mask[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
......@@ -276,12 +288,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
int *keep = keep_out->mutable_data<int>({num_to_keep}, ctx.GetPlace());
memory::Copy(place, keep, platform::CPUPlace(), keep_vec.data(),
sizeof(int) * num_to_keep, 0);
memory::Free(place, d_mask);
memory::Free(platform::CPUPlace(), h_mask);
}
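
For readers following the host-side sweep above: NMSKernel fills, for every box i, a row of col_blocks 64-bit words in mask, where bit j of word b is set when box i overlaps box b*64+j above the threshold; a box is kept only if no previously kept box has set its bit. A sketch of the membership test under that layout (i, remv, and kThreadsPerBlock as in the surrounding code):

```cuda
// kThreadsPerBlock is 64 (sizeof(uint64_t) * 8), i.e. one bit per box
// within a 64-bit word.
int nblock = i / kThreadsPerBlock;   // which 64-bit word holds box i
int inblock = i % kThreadsPerBlock;  // which bit inside that word
bool suppressed = (remv[nblock] & (1ULL << inblock)) != 0;
```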
template <typename T>
std::pair<Tensor, Tensor> ProposalForOneImage(
static std::pair<Tensor, Tensor> ProposalForOneImage(
const platform::CUDADeviceContext &ctx, const Tensor &im_info,
const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas, // [M, 4]
......@@ -300,18 +310,20 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
// 2. box decode and clipping
Tensor proposals;
proposals.mutable_data<T>({pre_nms_num, 4}, ctx.GetPlace());
int block = 512;
auto stream = ctx.stream();
BoxDecodeAndClipKernel<T><<<DIVUP(pre_nms_num, block), block, 0, stream>>>(
anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(),
index_sort.data<int>(), im_info.data<T>(), pre_nms_num,
proposals.data<T>());
{
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, pre_nms_num);
for_range(BoxDecodeAndClipFunctor<T>{
anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(),
index_sort.data<int>(), im_info.data<T>(), proposals.data<T>()});
}
// 3. filter
Tensor keep_index, keep_num_t;
keep_index.mutable_data<int>({pre_nms_num}, ctx.GetPlace());
keep_num_t.mutable_data<int>({1}, ctx.GetPlace());
min_size = std::max(min_size, 1.0f);
auto stream = ctx.stream();
FilterBBoxes<T, 512><<<1, 512, 0, stream>>>(
proposals.data<T>(), im_info.data<T>(), min_size, pre_nms_num,
keep_num_t.data<int>(), keep_index.data<int>());
......@@ -355,8 +367,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo");
auto *anchors = context.Input<Tensor>("Anchors");
auto *variances = context.Input<Tensor>("Variances");
auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
"Cannot find input Anchors(%s) in scope",
context.Inputs("Anchors")[0]);
auto variances = detail::Ref(context.Input<Tensor>("Variances"),
"Cannot find input Variances(%s) in scope",
context.Inputs("Variances")[0]);
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
......@@ -392,10 +408,8 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis);
Tensor *anchor = const_cast<framework::Tensor *>(anchors);
anchor->Resize({anchors->numel() / 4, 4});
Tensor *var = const_cast<framework::Tensor *>(variances);
var->Resize({var->numel() / 4, 4});
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
......@@ -404,7 +418,7 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
T *rpn_rois_data = rpn_rois->data<T>();
T *rpn_roi_probs_data = rpn_roi_probs->data<T>();
auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
auto &place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
int64_t num_proposals = 0;
std::vector<size_t> offset(1, 0);
......@@ -417,12 +431,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> box_score_pair =
ProposalForOneImage<T>(dev_ctx, im_info_slice, *anchor, *var,
ProposalForOneImage<T>(dev_ctx, im_info_slice, anchors, variances,
bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta);
Tensor proposals = box_score_pair.first;
Tensor scores = box_score_pair.second;
Tensor &proposals = box_score_pair.first;
Tensor &scores = box_score_pair.second;
memory::Copy(place, rpn_rois_data + num_proposals * 4, place,
proposals.data<T>(), sizeof(T) * proposals.numel(), 0);
......
......@@ -39,11 +39,9 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1);
int index_size = index.dims()[0];
int64_t index_size = index.dims()[0];
auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
const T* p_src = src.data<T>();
const int* p_index = index.data<int>();
......@@ -55,7 +53,7 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const size_t slice_bytes = slice_size * sizeof(T);
for (int i = 0; i < index_size; ++i) {
for (int64_t i = 0; i < index_size; ++i) {
int index_ = p_index[i];
memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
}
......