未验证 提交 b6eff442 编写于 作者: K Kaipeng Deng 提交者: GitHub

update yolo_box support h != w. test=develop (#27327)

上级 c1eed1fa
...@@ -26,8 +26,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, ...@@ -26,8 +26,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh, T* scores, const float conf_thresh,
const int* anchors, const int n, const int h, const int* anchors, const int n, const int h,
const int w, const int an_num, const int class_num, const int w, const int an_num, const int class_num,
const int box_num, int input_size, bool clip_bbox, const int box_num, int input_size_h,
const float scale, const float bias) { int input_size_w, bool clip_bbox, const float scale,
const float bias) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
T box[4]; T box[4];
...@@ -51,8 +52,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, ...@@ -51,8 +52,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int box_idx = int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx, GetYoloBox<T>(box, input, anchors, l, k, j, h, w, input_size_h,
grid_num, img_height, img_width, scale, bias); input_size_w, box_idx, grid_num, img_height, img_width, scale,
bias);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4; box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox); CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
...@@ -86,7 +88,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -86,7 +88,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
const int w = input->dims()[3]; const int w = input->dims()[3];
const int box_num = boxes->dims()[1]; const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2; const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h; int input_size_h = downsample_ratio * h;
int input_size_w = downsample_ratio * w;
auto& dev_ctx = ctx.cuda_device_context(); auto& dev_ctx = ctx.cuda_device_context();
int bytes = sizeof(int) * anchors.size(); int bytes = sizeof(int) * anchors.size();
...@@ -111,8 +114,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -111,8 +114,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh, input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size, anchors_data, n, h, w, an_num, class_num, box_num, input_size_h,
clip_bbox, scale, bias); input_size_w, clip_bbox, scale, bias);
} }
}; };
......
...@@ -27,17 +27,18 @@ HOSTDEVICE inline T sigmoid(T x) { ...@@ -27,17 +27,18 @@ HOSTDEVICE inline T sigmoid(T x) {
template <typename T> template <typename T>
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
int j, int an_idx, int grid_size, int j, int an_idx, int grid_size_h,
int input_size, int index, int stride, int grid_size_w, int input_size_h,
int input_size_w, int index, int stride,
int img_height, int img_width, float scale, int img_height, int img_width, float scale,
float bias) { float bias) {
box[0] = (i + sigmoid<T>(x[index]) * scale + bias) * img_width / grid_size; box[0] = (i + sigmoid<T>(x[index]) * scale + bias) * img_width / grid_size_w;
box[1] = (j + sigmoid<T>(x[index + stride]) * scale + bias) * img_height / box[1] = (j + sigmoid<T>(x[index + stride]) * scale + bias) * img_height /
grid_size; grid_size_h;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size; input_size_w;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size; img_height / input_size_h;
} }
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
...@@ -99,7 +100,8 @@ class YoloBoxKernel : public framework::OpKernel<T> { ...@@ -99,7 +100,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
const int w = input->dims()[3]; const int w = input->dims()[3];
const int box_num = boxes->dims()[1]; const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2; const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h; int input_size_h = downsample_ratio * h;
int input_size_w = downsample_ratio * w;
const int stride = h * w; const int stride = h * w;
const int an_stride = (class_num + 5) * stride; const int an_stride = (class_num + 5) * stride;
...@@ -134,8 +136,9 @@ class YoloBoxKernel : public framework::OpKernel<T> { ...@@ -134,8 +136,9 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int box_idx = int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size, GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, w,
box_idx, stride, img_height, img_width, scale, bias); input_size_h, input_size_w, box_idx, stride,
img_height, img_width, scale, bias);
box_idx = (i * box_num + j * stride + k * w + l) * 4; box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width, CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width,
clip_bbox); clip_bbox);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册