提交 04b8b9e9 编写于 作者: D dengkaipeng

add yolo_box_op CUDA kernel

上级 452373de
...@@ -35,7 +35,6 @@ class YoloBoxOp : public framework::OperatorWithKernel { ...@@ -35,7 +35,6 @@ class YoloBoxOp : public framework::OperatorWithKernel {
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors"); auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
int anchor_num = anchors.size() / 2; int anchor_num = anchors.size() / 2;
auto class_num = ctx->Attrs().Get<int>("class_num"); auto class_num = ctx->Attrs().Get<int>("class_num");
auto conf_thresh = ctx->Attrs().Get<float>("conf_thresh");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -20,15 +20,44 @@ namespace operators { ...@@ -20,15 +20,44 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
static __global__ void GenDensityPriorBox( __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
const int height, const int width, const int im_height, const int im_width, T* scores, const float conf_thresh,
const T offset, const T step_width, const T step_height, std::vector<int> anchors, const int h, const in w,
const int num_priors, const T* ratios_shift, bool is_clip, const T var_xmin, const int an_num, const int class_num,
const T var_ymin, const T var_xmax, const T var_ymax, T* out, T* var) { const int box_num, const int input_size) {
int gidx = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int gidy = blockIdx.y * blockDim.y + threadIdx.y; int stride = blockDim.x * gridDim.x;
int step_x = blockDim.x * gridDim.x; for (; tid < box_num; tid += stride) {
int step_y = blockDim.y * gridDim.y; int grid_num = h * w;
int i = tid / box_num;
int j = (tid % box_num) / grid_num;
int k = (tid % grid_num) / w;
int l = tid % w;
int an_stride = an_num * grid_num;
int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1];
int obj_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4);
T conf = sigmoid<T>(input[obj_idx]);
if (conf < conf_thresh) {
continue;
}
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
Box<T> pred = GetYoloBox<T>(input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, pred, box_idx);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
grid_num);
}
} }
template <typename T> template <typename T>
...@@ -36,6 +65,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -36,6 +65,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input"); auto* input = ctx.Input<Tensor>("Input");
auto* img_size = ctx.Input<Tensor>("ImgSize");
auto* boxes = ctx.Output<Tensor>("Boxes"); auto* boxes = ctx.Output<Tensor>("Boxes");
auto* scores = ctx.Output<Tensor>("Scores"); auto* scores = ctx.Output<Tensor>("Scores");
...@@ -51,14 +81,16 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -51,14 +81,16 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
const int an_num = anchors.size() / 2; const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h; int input_size = downsample_ratio * h;
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
T* boxes_data = boxes->mutable_data<T>({n}, ctx.GetPlace()); const int* imgsize_data = imgsize->data<int>();
memset(loss_data, 0, boxes->numel() * sizeof(T)); T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
T* scores_data = scores->mutable_data<T>({n}, ctx.GetPlace()); memset(boxes_data, 0, boxes->numel() * sizeof(T));
T* scores_data =
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
memset(scores_data, 0, scores->numel() * sizeof(T)); memset(scores_data, 0, scores->numel() * sizeof(T));
int grid_dim = (n * box_num + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
} }
}; // namespace operators }; // namespace operators
......
...@@ -30,7 +30,7 @@ static inline T sigmoid(T x) { ...@@ -30,7 +30,7 @@ static inline T sigmoid(T x) {
} }
template <typename T> template <typename T>
static inline Box<T> GetYoloBox(const T* x, std::vector<int> anchors, int i, HOSTDEVICE inline Box<T> GetYoloBox(const T* x, std::vector<int> anchors, int i,
int j, int an_idx, int grid_size, int j, int an_idx, int grid_size,
int input_size, int index, int stride, int input_size, int index, int stride,
int img_height, int img_width) { int img_height, int img_width) {
...@@ -44,13 +44,15 @@ static inline Box<T> GetYoloBox(const T* x, std::vector<int> anchors, int i, ...@@ -44,13 +44,15 @@ static inline Box<T> GetYoloBox(const T* x, std::vector<int> anchors, int i,
return b; return b;
} }
static inline int GetEntryIndex(int batch, int an_idx, int hw_idx, int an_num, HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
int an_stride, int stride, int entry) { int an_num, int an_stride, int stride,
int entry) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
} }
template <typename T> template <typename T>
static inline void CalcDetectionBox(T* boxes, Box<T> pred, const int box_idx) { HOSTDEVICE inline void CalcDetectionBox(T* boxes, Box<T> pred,
const int box_idx) {
boxes[box_idx] = pred.x - pred.w / 2; boxes[box_idx] = pred.x - pred.w / 2;
boxes[box_idx + 1] = pred.y - pred.h / 2; boxes[box_idx + 1] = pred.y - pred.h / 2;
boxes[box_idx + 2] = pred.x + pred.w / 2; boxes[box_idx + 2] = pred.x + pred.w / 2;
...@@ -58,7 +60,7 @@ static inline void CalcDetectionBox(T* boxes, Box<T> pred, const int box_idx) { ...@@ -58,7 +60,7 @@ static inline void CalcDetectionBox(T* boxes, Box<T> pred, const int box_idx) {
} }
template <typename T> template <typename T>
static inline void CalcLabelScore(T* scores, const T* input, HOSTDEVICE inline void CalcLabelScore(T* scores, const T* input,
const int label_idx, const int score_idx, const int label_idx, const int score_idx,
const int class_num, const T conf, const int class_num, const T conf,
const int stride) { const int stride) {
...@@ -115,8 +117,8 @@ class YoloBoxKernel : public framework::OpKernel<T> { ...@@ -115,8 +117,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int box_idx = int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
Box<T> pred = Box<T> pred =
GetYoloBox(input_data, anchors, l, k, j, h, input_size, box_idx, GetYoloBox<T>(input_data, anchors, l, k, j, h, input_size,
stride, img_height, img_width); box_idx, stride, img_height, img_width);
box_idx = (i * box_num + j * stride + k * w + l) * 4; box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, pred, box_idx); CalcDetectionBox<T>(boxes_data, pred, box_idx);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册