diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index a4513bd2f4785c923c1670d05e696d6ec5503849..c9b5a19f82fef1b0ac6fc0428a563cf01cc8ccc3 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -52,7 +52,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, GetYoloBox(box, input, anchors, l, k, j, h, input_size, box_idx, grid_num, img_height, img_width); box_idx = (i * box_num + j * grid_num + k * w + l) * 4; - CalcDetectionBox(boxes, box, box_idx); + CalcDetectionBox(boxes, box, box_idx, img_height, img_width); int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h index 6188c5f32b742e3ff62ec3247e2a0a48149f29e6..cf028a6e06f3d95cdca73915583649c56136fec3 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -46,12 +46,17 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, } template -HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, - const int box_idx) { +HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx, + const int img_height, const int img_width) { boxes[box_idx] = box[0] - box[2] / 2; boxes[box_idx + 1] = box[1] - box[3] / 2; boxes[box_idx + 2] = box[0] + box[2] / 2; boxes[box_idx + 3] = box[1] + box[3] / 2; + + boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast(0); + boxes[box_idx + 1] = boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0); + boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 ? boxes[box_idx + 2] : static_cast(img_width - 1); + boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 ? boxes[box_idx + 3] : static_cast(img_height - 1); } template @@ -118,7 +123,7 @@ class YoloBoxKernel : public framework::OpKernel { GetYoloBox(box, input_data, anchors_, l, k, j, h, input_size, box_idx, stride, img_height, img_width); box_idx = (i * box_num + j * stride + k * w + l) * 4; - CalcDetectionBox(boxes_data, box, box_idx); + CalcDetectionBox(boxes_data, box, box_idx, img_height, img_width); int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);