提交 72a18bb1 编写于 作者: D dengkaipeng

add bbox range limit. test=develop

上级 fb863b48
...@@ -52,7 +52,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, ...@@ -52,7 +52,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx, GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width); grid_num, img_height, img_width);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4; box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx); CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width);
int label_idx = int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
......
...@@ -46,12 +46,17 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, ...@@ -46,12 +46,17 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
} }
template <typename T> template <typename T>
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx,
const int box_idx) { const int img_height, const int img_width) {
boxes[box_idx] = box[0] - box[2] / 2; boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2; boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2; boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2; boxes[box_idx + 3] = box[1] + box[3] / 2;
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
boxes[box_idx + 1] = boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 ? boxes[box_idx + 2] : static_cast<T>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 ? boxes[box_idx + 3] : static_cast<T>(img_height - 1);
} }
template <typename T> template <typename T>
...@@ -118,7 +123,7 @@ class YoloBoxKernel : public framework::OpKernel<T> { ...@@ -118,7 +123,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
GetYoloBox<T>(box, input_data, anchors_, l, k, j, h, input_size, GetYoloBox<T>(box, input_data, anchors_, l, k, j, h, input_size,
box_idx, stride, img_height, img_width); box_idx, stride, img_height, img_width);
box_idx = (i * box_num + j * stride + k * w + l) * 4; box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx); CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width);
int label_idx = int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册