提交 7808f4c0 编写于 作者: D dengkaipeng

fix unittest for yolo_box_op. test=develop

上级 cb2dca53
......@@ -36,7 +36,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int k = (tid % grid_num) / w;
int l = tid % w;
int an_stride = an_num * grid_num;
int an_stride = (5 + class_num) * grid_num;
int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1];
......@@ -56,7 +56,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
grid_num);
}
......@@ -99,12 +99,12 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
set_zero(dev_ctx, boxes, static_cast<T>(0));
set_zero(dev_ctx, scores, static_cast<T>(0));
int grid_dim = (n * box_num + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
int grid_dim = (n * box_num + 4 - 1) / 4;
grid_dim = grid_dim > 2 ? 2 : grid_dim;
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
KeYoloBoxFw<T><<<grid_dim, 4, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, h, w, an_num, class_num, box_num, input_size);
anchors_data, h, w, an_num, class_num, n * box_num, input_size);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册