From 72a18bb16028a1c54a038eed15887b384d25f42a Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Tue, 26 Feb 2019 02:26:23 +0000 Subject: [PATCH] add bbox range limit. test=develop --- paddle/fluid/operators/detection/yolo_box_op.cu | 2 +- paddle/fluid/operators/detection/yolo_box_op.h | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index a4513bd2f47..c9b5a19f82f 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -52,7 +52,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, GetYoloBox(box, input, anchors, l, k, j, h, input_size, box_idx, grid_num, img_height, img_width); box_idx = (i * box_num + j * grid_num + k * w + l) * 4; - CalcDetectionBox(boxes, box, box_idx); + CalcDetectionBox(boxes, box, box_idx, img_height, img_width); int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h index 6188c5f32b7..cf028a6e06f 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -46,12 +46,17 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, } template -HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, - const int box_idx) { +HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx, + const int img_height, const int img_width) { boxes[box_idx] = box[0] - box[2] / 2; boxes[box_idx + 1] = box[1] - box[3] / 2; boxes[box_idx + 2] = box[0] + box[2] / 2; boxes[box_idx + 3] = box[1] + box[3] / 2; + + boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast(0); + boxes[box_idx + 1] = boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0); + boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 ? boxes[box_idx + 2] : static_cast(img_width - 1); + boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 ? boxes[box_idx + 3] : static_cast(img_height - 1); } template @@ -118,7 +123,7 @@ class YoloBoxKernel : public framework::OpKernel { GetYoloBox(box, input_data, anchors_, l, k, j, h, input_size, box_idx, stride, img_height, img_width); box_idx = (i * box_num + j * stride + k * w + l) * 4; - CalcDetectionBox(boxes_data, box, box_idx); + CalcDetectionBox(boxes_data, box, box_idx, img_height, img_width); int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); -- GitLab