未验证 提交 943a4449 编写于 作者: K Kaipeng Deng 提交者: GitHub

yolo_box OP add Attr(clip_bbox). (#21620)

* yolo_box OP add Attr(clip_bbox). test=develop
上级 c4f8f3bd
......@@ -43,9 +43,12 @@ class YoloBoxOp : public framework::OperatorWithKernel {
"+ class_num)).");
PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2,
"Input(ImgSize) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
dim_imgsize[0], dim_x[0],
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same.");
if ((dim_imgsize[0] > 0 && dim_x[0] > 0) || ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
dim_imgsize[0], dim_x[0],
platform::errors::InvalidArgument(
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same."));
}
PADDLE_ENFORCE_EQ(dim_imgsize[1], 2, "Input(ImgSize) dim[1] should be 2.");
PADDLE_ENFORCE_GT(anchors.size(), 0,
"Attr(anchors) length should be greater than 0.");
......@@ -110,6 +113,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"Boxes with confidence scores under threshold should "
"be ignored.")
.SetDefault(0.01);
AddAttr<bool>("clip_bbox",
"Whether clip output bonding box in Input(ImgSize) "
"boundary. Default true.")
.SetDefault(true);
AddComment(R"DOC(
This operator generates YOLO detection boxes from output of YOLOv3 network.
......
......@@ -26,7 +26,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh,
const int* anchors, const int n, const int h,
const int w, const int an_num, const int class_num,
const int box_num, int input_size) {
const int box_num, int input_size, bool clip_bbox) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T box[4];
......@@ -53,7 +53,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width);
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
......@@ -76,6 +76,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
int class_num = ctx.Attr<int>("class_num");
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
const int n = input->dims()[0];
const int h = input->dims()[2];
......@@ -107,7 +108,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size);
anchors_data, n, h, w, an_num, class_num, box_num, input_size,
clip_bbox);
}
};
......
......@@ -47,21 +47,23 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
template <typename T>
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx,
const int img_height,
const int img_width) {
const int img_width, bool clip_bbox) {
boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2;
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<T>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<T>(img_height - 1);
if (clip_bbox) {
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<T>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<T>(img_height - 1);
}
}
template <typename T>
......@@ -86,6 +88,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int class_num = ctx.Attr<int>("class_num");
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
const int n = input->dims()[0];
const int h = input->dims()[2];
......@@ -130,8 +133,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size,
box_idx, stride, img_height, img_width);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height,
img_width);
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width,
clip_bbox);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);
......
......@@ -1023,6 +1023,7 @@ def yolo_box(x,
class_num,
conf_thresh,
downsample_ratio,
clip_bbox=True,
name=None):
"""
${comment}
......@@ -1034,6 +1035,7 @@ def yolo_box(x,
class_num (int): ${class_num_comment}
conf_thresh (float): ${conf_thresh_comment}
downsample_ratio (int): ${downsample_ratio_comment}
clip_bbox (bool): ${clip_bbox_comment}
name (string): The default value is None. Normally there is no need
for user to set this property. For more information,
please refer to :ref:`api_guide_Name`
......@@ -1081,6 +1083,7 @@ def yolo_box(x,
"class_num": class_num,
"conf_thresh": conf_thresh,
"downsample_ratio": downsample_ratio,
"clip_bbox": clip_bbox,
}
helper.append_op(
......
......@@ -32,6 +32,7 @@ def YoloBox(x, img_size, attrs):
class_num = attrs['class_num']
conf_thresh = attrs['conf_thresh']
downsample = attrs['downsample']
clip_bbox = attrs['clip_bbox']
input_size = downsample * h
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
......@@ -64,13 +65,14 @@ def YoloBox(x, img_size, attrs):
pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis]
for i in range(len(pred_box)):
pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf)
pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf)
pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf,
img_size[i, 1] - 1)
pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf,
img_size[i, 0] - 1)
if clip_bbox:
for i in range(len(pred_box)):
pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf)
pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf)
pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf,
img_size[i, 1] - 1)
pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf,
img_size[i, 0] - 1)
return pred_box, pred_score.reshape((n, -1, class_num))
......@@ -87,6 +89,7 @@ class TestYoloBoxOp(OpTest):
"class_num": self.class_num,
"conf_thresh": self.conf_thresh,
"downsample": self.downsample,
"clip_bbox": self.clip_bbox,
}
self.inputs = {
......@@ -109,6 +112,20 @@ class TestYoloBoxOp(OpTest):
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = True
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2)
class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = False
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册