diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index c9b5a19f82fef1b0ac6fc0428a563cf01cc8ccc3..a0c60ae673fc23ecb41a0badb9c62207f734a5c7 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -22,9 +22,9 @@ using Tensor = framework::Tensor; template __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, - T* scores, const float conf_thresh, const int* anchors, - const int n, const int h, const int w, - const int an_num, const int class_num, + T* scores, const float conf_thresh, + const int* anchors, const int n, const int h, + const int w, const int an_num, const int class_num, const int box_num, int input_size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; @@ -50,7 +50,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); GetYoloBox(box, input, anchors, l, k, j, h, input_size, box_idx, - grid_num, img_height, img_width); + grid_num, img_height, img_width); box_idx = (i * box_num + j * grid_num + k * w + l) * 4; CalcDetectionBox(boxes, box, box_idx, img_height, img_width); @@ -84,7 +84,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { int input_size = downsample_ratio * h; Tensor anchors_t, cpu_anchors_t; - auto cpu_anchors_data = cpu_anchors_t.mutable_data({an_num*2}, platform::CPUPlace()); + auto cpu_anchors_data = + cpu_anchors_t.mutable_data({an_num * 2}, platform::CPUPlace()); std::copy(anchors.begin(), anchors.end(), cpu_anchors_data); TensorCopySync(cpu_anchors_t, ctx.GetPlace(), &anchors_t); auto anchors_data = anchors_t.data(); @@ -103,8 +104,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { grid_dim = grid_dim > 8 ? 8 : grid_dim; KeYoloBoxFw<<>>( - input_data, imgsize_data, boxes_data, scores_data, conf_thresh, - anchors_data, n, h, w, an_num, class_num, box_num, input_size); + input_data, imgsize_data, boxes_data, scores_data, conf_thresh, + anchors_data, n, h, w, an_num, class_num, box_num, input_size); } }; @@ -112,6 +113,5 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(yolo_box, - ops::YoloBoxOpCUDAKernel, +REGISTER_OP_CUDA_KERNEL(yolo_box, ops::YoloBoxOpCUDAKernel, ops::YoloBoxOpCUDAKernel); diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h index cf028a6e06f3d95cdca73915583649c56136fec3..546a5a66b4415829bd9b04bd7d4b7ced0ab9876e 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -20,7 +20,6 @@ namespace operators { using Tensor = framework::Tensor; - template HOSTDEVICE inline T sigmoid(T x) { return 1.0 / (1.0 + std::exp(-x)); @@ -28,15 +27,15 @@ HOSTDEVICE inline T sigmoid(T x) { template HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, - int j, int an_idx, int grid_size, - int input_size, int index, int stride, - int img_height, int img_width) { + int j, int an_idx, int grid_size, + int input_size, int index, int stride, + int img_height, int img_width) { box[0] = (i + sigmoid(x[index])) * img_width / grid_size; box[1] = (j + sigmoid(x[index + stride])) * img_height / grid_size; box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / - input_size; - box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * img_height / - input_size; + input_size; + box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * + img_height / input_size; } HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, @@ -47,16 +46,22 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, template HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx, - const int img_height, const int img_width) { + const int img_height, + const int img_width) { boxes[box_idx] = box[0] - box[2] / 2; boxes[box_idx + 1] = box[1] - box[3] / 2; boxes[box_idx + 2] = box[0] + box[2] / 2; boxes[box_idx + 3] = box[1] + box[3] / 2; - + boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast(0); - boxes[box_idx + 1] = boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0); - boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 ? boxes[box_idx + 2] : static_cast(img_width - 1); - boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 ? boxes[box_idx + 3] : static_cast(img_height - 1); + boxes[box_idx + 1] = + boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0); + boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 + ? boxes[box_idx + 2] + : static_cast(img_width - 1); + boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 + ? boxes[box_idx + 3] + : static_cast(img_height - 1); } template @@ -92,8 +97,10 @@ class YoloBoxKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; - int anchors_[anchors.size()]; - std::copy(anchors.begin(), anchors.end(), anchors_); + Tensor anchors_; + auto anchors_data = + anchors_.mutable_data({an_num * 2}, ctx.GetPlace()); + std::copy(anchors.begin(), anchors.end(), anchors_data); const T* input_data = input->data(); const int* imgsize_data = imgsize->data(); @@ -120,10 +127,11 @@ class YoloBoxKernel : public framework::OpKernel { int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); - GetYoloBox(box, input_data, anchors_, l, k, j, h, input_size, - box_idx, stride, img_height, img_width); + GetYoloBox(box, input_data, anchors_data, l, k, j, h, input_size, + box_idx, stride, img_height, img_width); box_idx = (i * box_num + j * stride + k * w + l) * 4; - CalcDetectionBox(boxes_data, box, box_idx, img_height, img_width); + CalcDetectionBox(boxes_data, box, box_idx, img_height, + img_width); int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py index a1da4f64b6a5951a6543ecc82579d5aabe292ec6..d4a179794c261b6e41e049c16c60eb852bb3e50f 100644 --- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -59,12 +59,19 @@ def YoloBox(x, img_size, attrs): pred_box[:, :, :2], pred_box[:, :, 2:4] = \ pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2., \ pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0 - # pred_box = pred_box * input_size pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis] pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis] pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis] pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis] + for i in range(len(pred_box)): + pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf) + pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf) + pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf, + img_size[i, 1] - 1) + pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf, + img_size[i, 0] - 1) + return pred_box, pred_score.reshape((n, -1, class_num)) @@ -93,8 +100,7 @@ class TestYoloBoxOp(OpTest): } def test_check_output(self): - place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3) + self.check_output() def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23]