提交 6bfa339c 编写于 作者: D dengkaipeng

update yolo_box support h != w. test=develop

上级 264e76ca
......@@ -26,8 +26,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh,
const int* anchors, const int n, const int h,
const int w, const int an_num, const int class_num,
const int box_num, int input_size, bool clip_bbox,
const float scale, const float bias) {
const int box_num, int input_size_h,
int input_size_w, bool clip_bbox, const float scale,
const float bias) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T box[4];
......@@ -51,8 +52,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width, scale, bias);
GetYoloBox<T>(box, input, anchors, l, k, j, h, w, input_size_h,
input_size_w, box_idx, grid_num, img_height, img_width, scale,
bias);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
......@@ -86,7 +88,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
const int w = input->dims()[3];
const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h;
int input_size_h = downsample_ratio * h;
int input_size_w = downsample_ratio * w;
auto& dev_ctx = ctx.cuda_device_context();
int bytes = sizeof(int) * anchors.size();
......@@ -111,8 +114,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size,
clip_bbox, scale, bias);
anchors_data, n, h, w, an_num, class_num, box_num, input_size_h,
input_size_w, clip_bbox, scale, bias);
}
};
......
......@@ -27,17 +27,18 @@ HOSTDEVICE inline T sigmoid(T x) {
template <typename T>
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
int j, int an_idx, int grid_size,
int input_size, int index, int stride,
int j, int an_idx, int grid_size_h,
int grid_size_w, int input_size_h,
int input_size_w, int index, int stride,
int img_height, int img_width, float scale,
float bias) {
box[0] = (i + sigmoid<T>(x[index]) * scale + bias) * img_width / grid_size;
box[0] = (i + sigmoid<T>(x[index]) * scale + bias) * img_width / grid_size_w;
box[1] = (j + sigmoid<T>(x[index + stride]) * scale + bias) * img_height /
grid_size;
grid_size_h;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size;
input_size_w;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size;
img_height / input_size_h;
}
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
......@@ -99,7 +100,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
const int w = input->dims()[3];
const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h;
int input_size_h = downsample_ratio * h;
int input_size_w = downsample_ratio * w;
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
......@@ -134,8 +136,9 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size,
box_idx, stride, img_height, img_width, scale, bias);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, w,
input_size_h, input_size_w, box_idx, stride,
img_height, img_width, scale, bias);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width,
clip_bbox);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册