提交 7808f4c0 编写于 作者: D dengkaipeng

fix unittest for yolo_box_op. test=develop

上级 cb2dca53
...@@ -36,7 +36,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, ...@@ -36,7 +36,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int k = (tid % grid_num) / w; int k = (tid % grid_num) / w;
int l = tid % w; int l = tid % w;
int an_stride = an_num * grid_num; int an_stride = (5 + class_num) * grid_num;
int img_height = imgsize[2 * i]; int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1]; int img_width = imgsize[2 * i + 1];
...@@ -56,7 +56,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, ...@@ -56,7 +56,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int label_idx = int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
int score_idx = (i * box_num + j * stride + k * w + l) * class_num; int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf, CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
grid_num); grid_num);
} }
...@@ -99,12 +99,12 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -99,12 +99,12 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
set_zero(dev_ctx, boxes, static_cast<T>(0)); set_zero(dev_ctx, boxes, static_cast<T>(0));
set_zero(dev_ctx, scores, static_cast<T>(0)); set_zero(dev_ctx, scores, static_cast<T>(0));
int grid_dim = (n * box_num + 512 - 1) / 512; int grid_dim = (n * box_num + 4 - 1) / 4;
grid_dim = grid_dim > 8 ? 8 : grid_dim; grid_dim = grid_dim > 2 ? 2 : grid_dim;
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( KeYoloBoxFw<T><<<grid_dim, 4, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh, input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, h, w, an_num, class_num, box_num, input_size); anchors_data, h, w, an_num, class_num, n * box_num, input_size);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册