diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index bc563107f8824b834216923f422f903f726be78f..fbe934c7eac8961e04f92cff89254446ff9d45b2 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -36,7 +36,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, int k = (tid % grid_num) / w; int l = tid % w; - int an_stride = an_num * grid_num; + int an_stride = (5 + class_num) * grid_num; int img_height = imgsize[2 * i]; int img_width = imgsize[2 * i + 1]; @@ -56,7 +56,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); - int score_idx = (i * box_num + j * stride + k * w + l) * class_num; + int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num; CalcLabelScore(scores, input, label_idx, score_idx, class_num, conf, grid_num); } @@ -99,12 +99,12 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { set_zero(dev_ctx, boxes, static_cast(0)); set_zero(dev_ctx, scores, static_cast(0)); - int grid_dim = (n * box_num + 512 - 1) / 512; - grid_dim = grid_dim > 8 ? 8 : grid_dim; + int grid_dim = (n * box_num + 4 - 1) / 4; + grid_dim = grid_dim > 2 ? 2 : grid_dim; - KeYoloBoxFw<<>>( + KeYoloBoxFw<<>>( input_data, imgsize_data, boxes_data, scores_data, conf_thresh, - anchors_data, h, w, an_num, class_num, box_num, input_size); + anchors_data, h, w, an_num, class_num, n * box_num, input_size); } };