yolo_box_op.cu 5.0 KB
Newer Older
D
dengkaipeng 已提交
1
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/fluid/memory/malloc.h"
D
dengkaipeng 已提交
16
#include "paddle/fluid/operators/detection/yolo_box_op.h"
D
dengkaipeng 已提交
17
#include "paddle/fluid/operators/math/math_function.h"
D
dengkaipeng 已提交
18 19 20 21 22 23 24

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Decodes YOLO head predictions into boxes and per-class scores.
//
// Work item `idx` in [0, n * box_num) is decomposed as
//   batch  = idx / box_num
//   anchor = (idx % box_num) / (h * w)
//   cell   = idx % (h * w)   (row = cell / w, col = cell % w)
// Threads walk the work items with a grid-stride loop, so any launch with at
// least one block covers every item. Items whose confidence (sigmoid of the
// entry located by GetEntryIndex(..., 4)) is below conf_thresh are skipped
// and their output slots are left untouched.
// NOTE(review): box_num is presumably an_num * h * w -- confirm against the
// caller / InferShape.
template <typename T>
__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
                            T* scores, const float conf_thresh,
                            const int* anchors, const int n, const int h,
                            const int w, const int an_num, const int class_num,
                            const int box_num, int input_size_h,
                            int input_size_w, bool clip_bbox, const float scale,
                            const float bias) {
  // Loop invariants: cells per feature map, the element distance between
  // consecutive anchors' channel groups, and the total number of work items.
  const int grid_num = h * w;
  const int an_stride = (5 + class_num) * grid_num;
  const int total = n * box_num;
  const int step = blockDim.x * gridDim.x;

  T box[4];
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total;
       idx += step) {
    const int batch = idx / box_num;
    const int anchor = (idx % box_num) / grid_num;
    const int cell = idx % grid_num;  // same value as row * w + col
    const int row = cell / w;
    const int col = cell % w;

    const int obj_idx =
        GetEntryIndex(batch, anchor, cell, an_num, an_stride, grid_num, 4);
    T conf = sigmoid<T>(input[obj_idx]);
    if (conf < conf_thresh) {
      continue;
    }

    // imgsize stores two ints per batch entry: height, then width.
    const int img_height = imgsize[2 * batch];
    const int img_width = imgsize[2 * batch + 1];

    const int pred_box_idx =
        GetEntryIndex(batch, anchor, cell, an_num, an_stride, grid_num, 0);
    GetYoloBox<T>(box, input, anchors, col, row, anchor, h, w, input_size_h,
                  input_size_w, pred_box_idx, grid_num, img_height, img_width,
                  scale, bias);

    // Flat position of this (batch, anchor, cell) slot in the outputs; boxes
    // hold 4 values per slot and scores hold class_num values per slot.
    const int out_pos = batch * box_num + anchor * grid_num + cell;
    CalcDetectionBox<T>(boxes, box, out_pos * 4, img_height, img_width,
                        clip_bbox);

    const int label_idx =
        GetEntryIndex(batch, anchor, cell, an_num, an_stride, grid_num, 5);
    CalcLabelScore<T>(scores, input, label_idx, out_pos * class_num,
                      class_num, conf, grid_num);
  }
}

template <typename T>
class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
 public:
  // Runs KeYoloBoxFw over input "X" and writes "Boxes" ({n, box_num, 4}) and
  // "Scores" ({n, box_num, class_num}). Both outputs are zero-filled first
  // because the kernel skips below-threshold slots without writing them.
  void Compute(const framework::ExecutionContext& ctx) const override {
    // The kernel iterates with a grid-stride loop, so capping the grid at
    // kMaxBlocks still covers every box slot.
    constexpr int kThreadsPerBlock = 512;
    constexpr int kMaxBlocks = 8;

    auto* input = ctx.Input<Tensor>("X");
    auto* img_size = ctx.Input<Tensor>("ImgSize");
    auto* boxes = ctx.Output<Tensor>("Boxes");
    auto* scores = ctx.Output<Tensor>("Scores");

    auto anchors = ctx.Attr<std::vector<int>>("anchors");
    int class_num = ctx.Attr<int>("class_num");
    float conf_thresh = ctx.Attr<float>("conf_thresh");
    int downsample_ratio = ctx.Attr<int>("downsample_ratio");
    bool clip_bbox = ctx.Attr<bool>("clip_bbox");
    float scale = ctx.Attr<float>("scale_x_y");
    float bias = -0.5 * (scale - 1.);

    // NOTE(review): assumes X is 4-D with batch at dim 0 and the spatial
    // extent at dims 2 and 3; only those dims are read here.
    const int n = input->dims()[0];
    const int h = input->dims()[2];
    const int w = input->dims()[3];
    const int box_num = boxes->dims()[1];
    // anchors is a flat list of pairs, one pair per anchor.
    const int an_num = anchors.size() / 2;
    int input_size_h = downsample_ratio * h;
    int input_size_w = downsample_ratio * w;

    auto& dev_ctx = ctx.cuda_device_context();

    T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
    T* scores_data =
        scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());

    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
    set_zero(dev_ctx, boxes, static_cast<T>(0));
    set_zero(dev_ctx, scores, static_cast<T>(0));

    // Nothing to decode (empty batch or empty feature map). A launch with a
    // zero-sized grid is an invalid configuration, so stop here once the
    // (empty) outputs exist.
    const int total = n * box_num;
    if (total <= 0) {
      return;
    }

    const T* input_data = input->data<T>();
    const int* imgsize_data = img_size->data<int>();

    // Upload the anchor table to device memory on the op's stream.
    const size_t bytes = sizeof(int) * anchors.size();
    auto anchors_ptr = memory::Alloc(dev_ctx, bytes);
    int* anchors_data = reinterpret_cast<int*>(anchors_ptr->ptr());
    const auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
    const auto cplace = platform::CPUPlace();
    memory::Copy(gplace, anchors_data, cplace, anchors.data(), bytes,
                 dev_ctx.stream());

    int grid_dim = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    grid_dim = grid_dim > kMaxBlocks ? kMaxBlocks : grid_dim;

    KeYoloBoxFw<T><<<grid_dim, kThreadsPerBlock, 0, dev_ctx.stream()>>>(
        input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
        anchors_data, n, h, w, an_num, class_num, box_num, input_size_h,
        input_size_w, clip_bbox, scale, bias);
  }
};
D
dengkaipeng 已提交
121 122 123 124 125

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// CUDA kernels for the yolo_box operator, instantiated for float and double.
REGISTER_OP_CUDA_KERNEL(
    yolo_box,
    ops::YoloBoxOpCUDAKernel<float>,
    ops::YoloBoxOpCUDAKernel<double>);