diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index 9fdf39a7a0b8ac0c3f0873eb3f0f699721469120..01edf7b41b2a8dc6cce65a6e31d89025e884e58d 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -26,8 +26,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, T* scores, const float conf_thresh, const int* anchors, const int n, const int h, const int w, const int an_num, const int class_num, - const int box_num, int input_size, bool clip_bbox, - const float scale, const float bias) { + const int box_num, int input_size_h, + int input_size_w, bool clip_bbox, const float scale, + const float bias) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; T box[4]; @@ -51,8 +52,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); - GetYoloBox(box, input, anchors, l, k, j, h, input_size, box_idx, - grid_num, img_height, img_width, scale, bias); + GetYoloBox(box, input, anchors, l, k, j, h, w, input_size_h, + input_size_w, box_idx, grid_num, img_height, img_width, scale, + bias); box_idx = (i * box_num + j * grid_num + k * w + l) * 4; CalcDetectionBox(boxes, box, box_idx, img_height, img_width, clip_bbox); @@ -86,7 +88,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { const int w = input->dims()[3]; const int box_num = boxes->dims()[1]; const int an_num = anchors.size() / 2; - int input_size = downsample_ratio * h; + int input_size_h = downsample_ratio * h; + int input_size_w = downsample_ratio * w; auto& dev_ctx = ctx.cuda_device_context(); int bytes = sizeof(int) * anchors.size(); @@ -111,8 +114,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { KeYoloBoxFw<<>>( input_data, imgsize_data, boxes_data, scores_data, conf_thresh, - anchors_data, n, h, w, an_num, class_num, box_num, input_size, - clip_bbox, scale, bias); + anchors_data, n, h, w, an_num, class_num, box_num, input_size_h, + input_size_w, clip_bbox, scale, bias); } }; diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h index 388467d37ba644abe651e1080ce14dd0d0e704bc..1cfef142bca7327cb039412719b7c002beb53cab 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -27,17 +27,18 @@ HOSTDEVICE inline T sigmoid(T x) { template HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, - int j, int an_idx, int grid_size, - int input_size, int index, int stride, + int j, int an_idx, int grid_size_h, + int grid_size_w, int input_size_h, + int input_size_w, int index, int stride, int img_height, int img_width, float scale, float bias) { - box[0] = (i + sigmoid(x[index]) * scale + bias) * img_width / grid_size; + box[0] = (i + sigmoid(x[index]) * scale + bias) * img_width / grid_size_w; box[1] = (j + sigmoid(x[index + stride]) * scale + bias) * img_height / - grid_size; + grid_size_h; box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / - input_size; + input_size_w; box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * - img_height / input_size; + img_height / input_size_h; } HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, @@ -99,7 +100,8 @@ class YoloBoxKernel : public framework::OpKernel { const int w = input->dims()[3]; const int box_num = boxes->dims()[1]; const int an_num = anchors.size() / 2; - int input_size = downsample_ratio * h; + int input_size_h = downsample_ratio * h; + int input_size_w = downsample_ratio * w; const int stride = h * w; const int an_stride = (class_num + 5) * stride; @@ -134,8 +136,9 @@ class YoloBoxKernel : public framework::OpKernel { int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); - GetYoloBox(box, input_data, anchors_data, l, k, j, h, input_size, - box_idx, stride, img_height, img_width, scale, bias); + GetYoloBox(box, input_data, anchors_data, l, k, j, h, w, + input_size_h, input_size_w, box_idx, stride, + img_height, img_width, scale, bias); box_idx = (i * box_num + j * stride + k * w + l) * 4; CalcDetectionBox(boxes_data, box, box_idx, img_height, img_width, clip_bbox);