diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc index f78a980674aa7075a953656a60ac7071d2f76273..c018a6498ae01f2bbb14604d3a783a8cff766b57 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -35,7 +35,6 @@ class YoloBoxOp : public framework::OperatorWithKernel { auto anchors = ctx->Attrs().Get>("anchors"); int anchor_num = anchors.size() / 2; auto class_num = ctx->Attrs().Get("class_num"); - auto conf_thresh = ctx->Attrs().Get("conf_thresh"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index 9cc94794f2f75367b95a87f8460334234e39583e..38b514fe90f218a7171cc13911cafafdff4af2cf 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -20,15 +20,44 @@ namespace operators { using Tensor = framework::Tensor; template -static __global__ void GenDensityPriorBox( - const int height, const int width, const int im_height, const int im_width, - const T offset, const T step_width, const T step_height, - const int num_priors, const T* ratios_shift, bool is_clip, const T var_xmin, - const T var_ymin, const T var_xmax, const T var_ymax, T* out, T* var) { - int gidx = blockIdx.x * blockDim.x + threadIdx.x; - int gidy = blockIdx.y * blockDim.y + threadIdx.y; - int step_x = blockDim.x * gridDim.x; - int step_y = blockDim.y * gridDim.y; +__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, + T* scores, const float conf_thresh, + std::vector anchors, const int h, const in w, + const int an_num, const int class_num, + const int box_num, const int input_size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < box_num; tid += stride) { + int grid_num = h * w; + int i = tid / box_num; + int j = (tid % box_num) / grid_num; + int k = (tid % grid_num) / w; + int l = tid % w; + + int an_stride = an_num * grid_num; + int img_height = imgsize[2 * i]; + int img_width = imgsize[2 * i + 1]; + + int obj_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4); + T conf = sigmoid(input[obj_idx]); + if (conf < conf_thresh) { + continue; + } + + int box_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); + Box pred = GetYoloBox(input, anchors, l, k, j, h, input_size, box_idx, + grid_num, img_height, img_width); + box_idx = (i * box_num + j * grid_num + k * w + l) * 4; + CalcDetectionBox(boxes, pred, box_idx); + + int label_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); + int score_idx = (i * box_num + j * stride + k * w + l) * class_num; + CalcLabelScore(scores, input, label_idx, score_idx, class_num, conf, + grid_num); + } } template @@ -36,6 +65,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); + auto* img_size = ctx.Input("ImgSize"); auto* boxes = ctx.Output("Boxes"); auto* scores = ctx.Output("Scores"); @@ -51,14 +81,16 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { const int an_num = anchors.size() / 2; int input_size = downsample_ratio * h; - const int stride = h * w; - const int an_stride = (class_num + 5) * stride; - const T* input_data = input->data(); - T* boxes_data = boxes->mutable_data({n}, ctx.GetPlace()); - memset(loss_data, 0, boxes->numel() * sizeof(T)); - T* scores_data = scores->mutable_data({n}, ctx.GetPlace()); + const int* imgsize_data = imgsize->data(); + T* boxes_data = boxes->mutable_data({n, box_num, 4}, ctx.GetPlace()); + memset(boxes_data, 0, boxes->numel() * sizeof(T)); + T* scores_data = + scores->mutable_data({n, box_num, class_num}, ctx.GetPlace()); memset(scores_data, 0, scores->numel() * sizeof(T)); + + int grid_dim = (n * box_num + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; } }; // namespace operators diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h index 0ea8c178616ba8685f0d2c2fb6786ba8d34e4a4c..90933e123e098407dfc06124c4f2bc3cee9cb12a 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -30,10 +30,10 @@ static inline T sigmoid(T x) { } template -static inline Box GetYoloBox(const T* x, std::vector anchors, int i, - int j, int an_idx, int grid_size, - int input_size, int index, int stride, - int img_height, int img_width) { +HOSTDEVICE inline Box GetYoloBox(const T* x, std::vector anchors, int i, + int j, int an_idx, int grid_size, + int input_size, int index, int stride, + int img_height, int img_width) { Box b; b.x = (i + sigmoid(x[index])) * img_width / grid_size; b.y = (j + sigmoid(x[index + stride])) * img_height / grid_size; @@ -44,13 +44,15 @@ static inline Box GetYoloBox(const T* x, std::vector anchors, int i, return b; } -static inline int GetEntryIndex(int batch, int an_idx, int hw_idx, int an_num, - int an_stride, int stride, int entry) { +HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, + int an_num, int an_stride, int stride, + int entry) { return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; } template -static inline void CalcDetectionBox(T* boxes, Box pred, const int box_idx) { +HOSTDEVICE inline void CalcDetectionBox(T* boxes, Box pred, + const int box_idx) { boxes[box_idx] = pred.x - pred.w / 2; boxes[box_idx + 1] = pred.y - pred.h / 2; boxes[box_idx + 2] = pred.x + pred.w / 2; @@ -58,10 +60,10 @@ static inline void CalcDetectionBox(T* boxes, Box pred, const int box_idx) { } template -static inline void CalcLabelScore(T* scores, const T* input, - const int label_idx, const int score_idx, - const int class_num, const T conf, - const int stride) { +HOSTDEVICE inline void CalcLabelScore(T* scores, const T* input, + const int label_idx, const int score_idx, + const int class_num, const T conf, + const int stride) { for (int i = 0; i < class_num; i++) { scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]); } @@ -115,8 +117,8 @@ class YoloBoxKernel : public framework::OpKernel { int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); Box pred = - GetYoloBox(input_data, anchors, l, k, j, h, input_size, box_idx, - stride, img_height, img_width); + GetYoloBox(input_data, anchors, l, k, j, h, input_size, + box_idx, stride, img_height, img_width); box_idx = (i * box_num + j * stride + k * w + l) * 4; CalcDetectionBox(boxes_data, pred, box_idx);