diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index fbe934c7eac8961e04f92cff89254446ff9d45b2..a4513bd2f4785c923c1670d05e696d6ec5503849 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -22,14 +22,14 @@ using Tensor = framework::Tensor; template __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, - T* scores, const float conf_thresh, - const int* anchors, const int h, const int w, + T* scores, const float conf_thresh, const int* anchors, + const int n, const int h, const int w, const int an_num, const int class_num, const int box_num, int input_size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; T box[4]; - for (; tid < box_num; tid += stride) { + for (; tid < n * box_num; tid += stride) { int grid_num = h * w; int i = tid / box_num; int j = (tid % box_num) / grid_num; @@ -99,12 +99,12 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { set_zero(dev_ctx, boxes, static_cast(0)); set_zero(dev_ctx, scores, static_cast(0)); - int grid_dim = (n * box_num + 4 - 1) / 4; - grid_dim = grid_dim > 2 ? 2 : grid_dim; + int grid_dim = (n * box_num + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; - KeYoloBoxFw<<>>( + KeYoloBoxFw<<>>( input_data, imgsize_data, boxes_data, scores_data, conf_thresh, - anchors_data, h, w, an_num, class_num, n * box_num, input_size); + anchors_data, n, h, w, an_num, class_num, box_num, input_size); } }; diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py index e28c05e3e6de2438e4a67c34c4f087d3c1c5fc15..a1da4f64b6a5951a6543ecc82579d5aabe292ec6 100644 --- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -99,11 +99,11 @@ class TestYoloBoxOp(OpTest): def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23] an_num = int(len(self.anchors) // 2) - self.batch_size = 1 + self.batch_size = 32 self.class_num = 2 self.conf_thresh = 0.5 self.downsample = 32 - self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 2, 2) + self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.imgsize_shape = (self.batch_size, 2)