提交 c9d4676b 编写于 作者: D dengkaipeng

fix multi batch idx error. test=develop

上级 7808f4c0
...@@ -22,14 +22,14 @@ using Tensor = framework::Tensor; ...@@ -22,14 +22,14 @@ using Tensor = framework::Tensor;
template <typename T> template <typename T>
__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh, T* scores, const float conf_thresh, const int* anchors,
const int* anchors, const int h, const int w, const int n, const int h, const int w,
const int an_num, const int class_num, const int an_num, const int class_num,
const int box_num, int input_size) { const int box_num, int input_size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
T box[4]; T box[4];
for (; tid < box_num; tid += stride) { for (; tid < n * box_num; tid += stride) {
int grid_num = h * w; int grid_num = h * w;
int i = tid / box_num; int i = tid / box_num;
int j = (tid % box_num) / grid_num; int j = (tid % box_num) / grid_num;
...@@ -99,12 +99,12 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -99,12 +99,12 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
set_zero(dev_ctx, boxes, static_cast<T>(0)); set_zero(dev_ctx, boxes, static_cast<T>(0));
set_zero(dev_ctx, scores, static_cast<T>(0)); set_zero(dev_ctx, scores, static_cast<T>(0));
int grid_dim = (n * box_num + 4 - 1) / 4; int grid_dim = (n * box_num + 512 - 1) / 512;
grid_dim = grid_dim > 2 ? 2 : grid_dim; grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeYoloBoxFw<T><<<grid_dim, 4, 0, ctx.cuda_device_context().stream()>>>( KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh, input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, h, w, an_num, class_num, n * box_num, input_size); anchors_data, n, h, w, an_num, class_num, box_num, input_size);
} }
}; };
......
...@@ -99,11 +99,11 @@ class TestYoloBoxOp(OpTest): ...@@ -99,11 +99,11 @@ class TestYoloBoxOp(OpTest):
def initTestCase(self): def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23] self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2) an_num = int(len(self.anchors) // 2)
self.batch_size = 1 self.batch_size = 32
self.class_num = 2 self.class_num = 2
self.conf_thresh = 0.5 self.conf_thresh = 0.5
self.downsample = 32 self.downsample = 32
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 2, 2) self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2) self.imgsize_shape = (self.batch_size, 2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册