From 7808f4c097cdac0eabd694f128dee4c93cd95788 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Tue, 19 Feb 2019 13:18:33 +0000 Subject: [PATCH] fix unittest for yolo_box_op. test=develop --- paddle/fluid/operators/detection/yolo_box_op.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index bc563107f88..fbe934c7eac 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -36,7 +36,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, int k = (tid % grid_num) / w; int l = tid % w; - int an_stride = an_num * grid_num; + int an_stride = (5 + class_num) * grid_num; int img_height = imgsize[2 * i]; int img_width = imgsize[2 * i + 1]; @@ -56,7 +56,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); - int score_idx = (i * box_num + j * stride + k * w + l) * class_num; + int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num; CalcLabelScore(scores, input, label_idx, score_idx, class_num, conf, grid_num); } @@ -99,12 +99,12 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { set_zero(dev_ctx, boxes, static_cast(0)); set_zero(dev_ctx, scores, static_cast(0)); - int grid_dim = (n * box_num + 512 - 1) / 512; - grid_dim = grid_dim > 8 ? 8 : grid_dim; + int grid_dim = (n * box_num + 4 - 1) / 4; + grid_dim = grid_dim > 2 ? 2 : grid_dim; - KeYoloBoxFw<<>>( + KeYoloBoxFw<<>>( input_data, imgsize_data, boxes_data, scores_data, conf_thresh, - anchors_data, h, w, an_num, class_num, box_num, input_size); + anchors_data, h, w, an_num, class_num, n * box_num, input_size); } }; -- GitLab