From b154470ccd8f24e90322b2ed6e814e8684b4e01b Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Wed, 9 Jun 2021 10:36:32 +0800 Subject: [PATCH] add two attributes for yolo box (#33400) * add two attributes for yolo box --- .../fluid/operators/detection/yolo_box_op.cc | 67 +++++++++++++--- .../fluid/operators/detection/yolo_box_op.cu | 25 ++++-- .../fluid/operators/detection/yolo_box_op.h | 37 +++++++-- python/paddle/fluid/layers/detection.py | 8 +- .../fluid/tests/unittests/test_yolo_box_op.py | 77 +++++++++++++++++-- python/paddle/vision/ops.py | 26 +++++-- 6 files changed, 202 insertions(+), 38 deletions(-) diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc index 6f2a3ca8762..e6f6c2a3935 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -11,6 +11,7 @@ #include "paddle/fluid/operators/detection/yolo_box_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -31,19 +32,44 @@ class YoloBoxOp : public framework::OperatorWithKernel { auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors"); int anchor_num = anchors.size() / 2; auto class_num = ctx->Attrs().Get<int>("class_num"); + auto iou_aware = ctx->Attrs().Get<bool>("iou_aware"); + auto iou_aware_factor = ctx->Attrs().Get<float>("iou_aware_factor"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, platform::errors::InvalidArgument( "Input(X) should be a 4-D tensor." "But received X dimension(%s)", dim_x.size())); - PADDLE_ENFORCE_EQ( - dim_x[1], anchor_num * (5 + class_num), - platform::errors::InvalidArgument( - "Input(X) dim[1] should be equal to (anchor_mask_number * (5 " - "+ class_num))." - "But received dim[1](%s) != (anchor_mask_number * " - "(5+class_num)(%s).", - dim_x[1], anchor_num * (5 + class_num))); + if (iou_aware) { + PADDLE_ENFORCE_EQ( + dim_x[1], anchor_num * (6 + class_num), + platform::errors::InvalidArgument( + "Input(X) dim[1] should be equal to (anchor_mask_number * (6 " + "+ class_num)) while iou_aware is true." + "But received dim[1](%s) != (anchor_mask_number * " + "(6+class_num)(%s).", + dim_x[1], anchor_num * (6 + class_num))); + PADDLE_ENFORCE_GE( + iou_aware_factor, 0, + platform::errors::InvalidArgument( + "Attr(iou_aware_factor) should be greater than or equal to 0." + "But received iou_aware_factor (%s)", + iou_aware_factor)); + PADDLE_ENFORCE_LE( + iou_aware_factor, 1, + platform::errors::InvalidArgument( + "Attr(iou_aware_factor) should be less than or equal to 1." + "But received iou_aware_factor (%s)", + iou_aware_factor)); + } else { + PADDLE_ENFORCE_EQ( + dim_x[1], anchor_num * (5 + class_num), + platform::errors::InvalidArgument( + "Input(X) dim[1] should be equal to (anchor_mask_number * (5 " + "+ class_num))." + "But received dim[1](%s) != (anchor_mask_number * " + "(5+class_num)(%s).", + dim_x[1], anchor_num * (5 + class_num))); + } PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2, platform::errors::InvalidArgument( "Input(ImgSize) should be a 2-D tensor." @@ -140,6 +166,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { "Scale the center point of decoded bounding " "box. Default 1.0") .SetDefault(1.); + AddAttr<bool>("iou_aware", "Whether to use IoU-aware confidence scoring. Default false.") .SetDefault(false); + AddAttr<float>("iou_aware_factor", "The IoU-aware factor. Default 0.5.") .SetDefault(0.5); AddComment(R"DOC( This operator generates YOLO detection boxes from output of YOLOv3 network.
@@ -147,7 +177,8 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { should be the same, H and W specify the grid size, each grid point predict given number boxes, this given number, which following will be represented as S, is specified by the number of anchors. In the second dimension(the channel - dimension), C should be equal to S * (5 + class_num), class_num is the object + dimension), C should be equal to S * (5 + class_num) if :attr:`iou_aware` is false, + otherwise C should be equal to S * (6 + class_num). class_num is the object category number of source dataset(such as 80 in coco dataset), so the second(channel) dimension, apart from 4 box location coordinates x, y, w, h, also includes confidence score of the box and class one-hot key of each anchor @@ -183,6 +214,15 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { score_{pred} = score_{conf} * score_{class} $$ + where the confidence scores follow the formula below + + .. math:: + + score_{conf} = \begin{cases} + obj, \text{if } iou_aware == false \\ + obj^{1 - iou_aware_factor} * iou^{iou_aware_factor}, \text{otherwise} + \end{cases} + )DOC"); } }; @@ -197,3 +237,12 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel<float>, ops::YoloBoxKernel<double>); + +REGISTER_OP_VERSION(yolo_box) .AddCheckpoint( R"ROC( Upgrade yolo_box to add new attributes [iou_aware, iou_aware_factor]. )ROC", paddle::framework::compatible::OpVersionDesc() .NewAttr("iou_aware", "Whether to use IoU-aware confidence scoring", false) .NewAttr("iou_aware_factor", "The IoU-aware factor", 0.5f)); diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index 65dc73ef383..ef0b870ebfd 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -28,7 +28,8 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, const int w, const int an_num, const int class_num, const int box_num, int input_size_h, int input_size_w, bool clip_bbox, const float scale, - const float bias) { + const float bias, bool iou_aware, + const float iou_aware_factor) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; T box[4]; @@ -43,23 +44,29 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, int img_height = imgsize[2 * i]; int img_width = imgsize[2 * i + 1]; - int obj_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4); + int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4, + iou_aware); T conf = sigmoid(input[obj_idx]); + if (iou_aware) { + int iou_idx = GetIoUIndex(i, j, k * w + l, an_num, an_stride, grid_num); + T iou = sigmoid(input[iou_idx]); + conf = pow(conf, static_cast<T>(1.
- iou_aware_factor)) * + pow(iou, static_cast<T>(iou_aware_factor)); + } if (conf < conf_thresh) { continue; } - int box_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); + int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0, + iou_aware); GetYoloBox(box, input, anchors, l, k, j, h, w, input_size_h, input_size_w, box_idx, grid_num, img_height, img_width, scale, bias); box_idx = (i * box_num + j * grid_num + k * w + l) * 4; CalcDetectionBox(boxes, box, box_idx, img_height, img_width, clip_bbox); - int label_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); + int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, + 5, iou_aware); int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num; CalcLabelScore(scores, input, label_idx, score_idx, class_num, conf, grid_num); @@ -80,6 +87,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { float conf_thresh = ctx.Attr<float>("conf_thresh"); int downsample_ratio = ctx.Attr<int>("downsample_ratio"); bool clip_bbox = ctx.Attr<bool>("clip_bbox"); + bool iou_aware = ctx.Attr<bool>("iou_aware"); + float iou_aware_factor = ctx.Attr<float>("iou_aware_factor"); float scale = ctx.Attr<float>("scale_x_y"); float bias = -0.5 * (scale - 1.); @@ -115,7 +124,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ctx.cuda_device_context().stream()>>>( input_data, imgsize_data, boxes_data, scores_data, conf_thresh, anchors_data, n, h, w, an_num, class_num, box_num, input_size_h, - input_size_w, clip_bbox, scale, bias); + input_size_w, clip_bbox, scale, bias, iou_aware, iou_aware_factor); } }; diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h index 1cfef142bca..e06c81052a0 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -13,6 +13,7 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/hostdevice.h" namespace paddle { namespace operators { @@ -43,8 +44,19 @@ HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, int an_num, int an_stride, int stride, - int entry) { - return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; + int entry, bool iou_aware) { + if (iou_aware) { + return (batch * an_num + an_idx) * an_stride + + (batch * an_num + an_num + entry) * stride + hw_idx; + } else { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; + } +} + +HOSTDEVICE inline int GetIoUIndex(int batch, int an_idx, int hw_idx, int an_num, + int an_stride, int stride) { + return batch * an_num * an_stride + (batch * an_num + an_idx) * stride + + hw_idx; +} template <typename T> @@ -92,6 +104,8 @@ class YoloBoxKernel : public framework::OpKernel<T> { float conf_thresh = ctx.Attr<float>("conf_thresh"); int downsample_ratio = ctx.Attr<int>("downsample_ratio"); bool clip_bbox = ctx.Attr<bool>("clip_bbox"); + bool iou_aware = ctx.Attr<bool>("iou_aware"); + float iou_aware_factor = ctx.Attr<float>("iou_aware_factor"); float scale = ctx.Attr<float>("scale_x_y"); float bias = -0.5 * (scale - 1.); @@ -127,15 +141,22 @@ class YoloBoxKernel : public framework::OpKernel<T> { for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { - int obj_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 4); + int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, + stride, 4, iou_aware); T conf
= sigmoid(input_data[obj_idx]); + if (iou_aware) { + int iou_idx = + GetIoUIndex(i, j, k * w + l, an_num, an_stride, stride); + T iou = sigmoid(input_data[iou_idx]); + conf = pow(conf, static_cast<T>(1. - iou_aware_factor)) * + pow(iou, static_cast<T>(iou_aware_factor)); + } if (conf < conf_thresh) { continue; } - int box_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); + int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, + stride, 0, iou_aware); GetYoloBox(box, input_data, anchors_data, l, k, j, h, w, input_size_h, input_size_w, box_idx, stride, img_height, img_width, scale, bias); box_idx = (i * box_num + j * stride + k * w + l) * 4; CalcDetectionBox(boxes_data, box, box_idx, img_height, img_width, clip_bbox); - int label_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); + int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, + stride, 5, iou_aware); int score_idx = (i * box_num + j * stride + k * w + l) * class_num; CalcLabelScore(scores_data, input_data, label_idx, score_idx, class_num, conf, stride); diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index cf4abc207bd..604bcc0e277 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -1139,7 +1139,9 @@ def yolo_box(x, downsample_ratio, clip_bbox=True, name=None, - scale_x_y=1.): + scale_x_y=1., + iou_aware=False, + iou_aware_factor=0.5): """ ${comment} @@ -1156,6 +1158,8 @@ def yolo_box(x, name (string): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` + iou_aware (bool): ${iou_aware_comment} + iou_aware_factor (float): ${iou_aware_factor_comment} Returns: Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes, @@ -1204,6 +1208,8 @@ def yolo_box(x, "downsample_ratio": downsample_ratio, "clip_bbox": clip_bbox, "scale_x_y": scale_x_y, + "iou_aware": iou_aware, + "iou_aware_factor": iou_aware_factor } helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py index 24c463ebfc9..5793f0148fc 100644 --- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -35,10 +35,16 @@ def YoloBox(x, img_size, attrs): downsample = attrs['downsample'] clip_bbox = attrs['clip_bbox'] scale_x_y = attrs['scale_x_y'] + iou_aware = attrs['iou_aware'] + iou_aware_factor = attrs['iou_aware_factor'] bias_x_y = -0.5 * (scale_x_y - 1.) input_h = downsample * h input_w = downsample * w + if iou_aware: + ioup = x[:, :an_num, :, :] + ioup = np.expand_dims(ioup, axis=-1) + x = x[:, an_num:, :, :] x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) pred_box = x[:, :, :, :, :4].copy() @@ -57,7 +63,11 @@ def YoloBox(x, img_size, attrs): pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h - pred_conf = sigmoid(x[:, :, :, :, 4:5]) + if iou_aware: + pred_conf = sigmoid(x[:, :, :, :, 4:5])**( + 1 - iou_aware_factor) * sigmoid(ioup)**iou_aware_factor + else: + pred_conf = sigmoid(x[:, :, :, :, 4:5]) pred_conf[pred_conf < conf_thresh] = 0.
pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf pred_box = pred_box * (pred_conf > 0.).astype('float32') @@ -97,6 +107,8 @@ class TestYoloBoxOp(OpTest): "downsample": self.downsample, "clip_bbox": self.clip_bbox, "scale_x_y": self.scale_x_y, + "iou_aware": self.iou_aware, + "iou_aware_factor": self.iou_aware_factor } self.inputs = { @@ -123,6 +135,8 @@ class TestYoloBoxOp(OpTest): self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.imgsize_shape = (self.batch_size, 2) self.scale_x_y = 1. + self.iou_aware = False + self.iou_aware_factor = 0.5 class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): @@ -137,6 +151,8 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.imgsize_shape = (self.batch_size, 2) self.scale_x_y = 1. + self.iou_aware = False + self.iou_aware_factor = 0.5 class TestYoloBoxOpScaleXY(TestYoloBoxOp): @@ -151,19 +167,36 @@ class TestYoloBoxOpScaleXY(TestYoloBoxOp): self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.imgsize_shape = (self.batch_size, 2) self.scale_x_y = 1.2 + self.iou_aware = False + self.iou_aware_factor = 0.5 + + +class TestYoloBoxOpIoUAware(TestYoloBoxOp): + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int(len(self.anchors) // 2) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.clip_bbox = True + self.x_shape = (self.batch_size, an_num * (6 + self.class_num), 13, 13) + self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1. + self.iou_aware = True + self.iou_aware_factor = 0.5 class TestYoloBoxDygraph(unittest.TestCase): def test_dygraph(self): paddle.disable_static() - x = np.random.random([2, 14, 8, 8]).astype('float32') img_size = np.ones((2, 2)).astype('int32') - - x = paddle.to_tensor(x) img_size = paddle.to_tensor(img_size) + x1 = np.random.random([2, 14, 8, 8]).astype('float32') + x1 = paddle.to_tensor(x1) boxes, scores = paddle.vision.ops.yolo_box( - x, + x1, img_size=img_size, anchors=[10, 13, 16, 30], class_num=2, @@ -172,16 +205,30 @@ class TestYoloBoxDygraph(unittest.TestCase): clip_bbox=True, scale_x_y=1.) assert boxes is not None and scores is not None + + x2 = np.random.random([2, 16, 8, 8]).astype('float32') + x2 = paddle.to_tensor(x2) + boxes, scores = paddle.vision.ops.yolo_box( + x2, + img_size=img_size, + anchors=[10, 13, 16, 30], + class_num=2, + conf_thresh=0.01, + downsample_ratio=8, + clip_bbox=True, + scale_x_y=1., + iou_aware=True, + iou_aware_factor=0.5) paddle.enable_static() class TestYoloBoxStatic(unittest.TestCase): def test_static(self): - x = paddle.static.data('x', [2, 14, 8, 8], 'float32') + x1 = paddle.static.data('x1', [2, 14, 8, 8], 'float32') img_size = paddle.static.data('img_size', [2, 2], 'int32') boxes, scores = paddle.vision.ops.yolo_box( - x, + x1, img_size=img_size, anchors=[10, 13, 16, 30], class_num=2, @@ -191,6 +238,20 @@ class TestYoloBoxStatic(unittest.TestCase): scale_x_y=1.) 
assert boxes is not None and scores is not None + x2 = paddle.static.data('x2', [2, 16, 8, 8], 'float32') + boxes, scores = paddle.vision.ops.yolo_box( + x2, + img_size=img_size, + anchors=[10, 13, 16, 30], + class_num=2, + conf_thresh=0.01, + downsample_ratio=8, + clip_bbox=True, + scale_x_y=1., + iou_aware=True, + iou_aware_factor=0.5) + assert boxes is not None and scores is not None + class TestYoloBoxOpHW(TestYoloBoxOp): def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23] an_num = int(len(self.anchors) // 2) self.batch_size = 32 self.class_num = 2 self.conf_thresh = 0.5 self.downsample = 32 self.clip_bbox = True self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 9) self.imgsize_shape = (self.batch_size, 2) self.scale_x_y = 1. + self.iou_aware = False + self.iou_aware_factor = 0.5 if __name__ == "__main__": diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 60a7a90c9be..769e33c7355 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -247,7 +247,9 @@ def yolo_box(x, downsample_ratio, clip_bbox=True, name=None, - scale_x_y=1.): + scale_x_y=1., + iou_aware=False, + iou_aware_factor=0.5): r""" This operator generates YOLO detection boxes from output of YOLOv3 network. @@ -256,7 +258,8 @@ def yolo_box(x, should be the same, H and W specify the grid size, each grid point predict given number boxes, this given number, which following will be represented as S, is specified by the number of anchors. In the second dimension(the channel - dimension), C should be equal to S * (5 + class_num), class_num is the object + dimension), C should be equal to S * (5 + class_num) if :attr:`iou_aware` is false, + otherwise C should be equal to S * (6 + class_num). class_num is the object category number of source dataset(such as 80 in coco dataset), so the second(channel) dimension, apart from 4 box location coordinates x, y, w, h, also includes confidence score of the box and class one-hot key of each anchor @@ -292,6 +295,15 @@ def yolo_box(x, score_{pred} = score_{conf} * score_{class} $$ + where the confidence scores follow the formula below + + .. math:: + + score_{conf} = \begin{cases} + obj, \text{if } iou_aware == false \\ + obj^{1 - iou_aware_factor} * iou^{iou_aware_factor}, \text{otherwise} + \end{cases} + Args: x (Tensor): The input tensor of YoloBox operator is a 4-D tensor with shape of [N, C, H, W]. The second dimension(C) stores box @@ -313,13 +325,14 @@ should be set for the first, second, and third :attr:`yolo_box` layer. clip_bbox (bool): Whether clip output bounding box in :attr:`img_size` - boundary. Default true." - " + boundary. Default true. scale_x_y (float): Scale the center point of decoded bounding box. Default 1.0 name (string): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` + iou_aware (bool): Whether to use IoU-aware confidence scoring. Default false + iou_aware_factor (float): The IoU-aware factor.
Default 0.5 Returns: Tensor: A 3-D tensor with shape [N, M, 4], the coordinates of boxes, @@ -358,7 +371,8 @@ def yolo_box(x, boxes, scores = core.ops.yolo_box( x, img_size, 'anchors', anchors, 'class_num', class_num, 'conf_thresh', conf_thresh, 'downsample_ratio', downsample_ratio, - 'clip_bbox', clip_bbox, 'scale_x_y', scale_x_y) + 'clip_bbox', clip_bbox, 'scale_x_y', scale_x_y, 'iou_aware', + iou_aware, 'iou_aware_factor', iou_aware_factor) return boxes, scores helper = LayerHelper('yolo_box', **locals()) @@ -378,6 +392,8 @@ def yolo_box(x, "downsample_ratio": downsample_ratio, "clip_bbox": clip_bbox, "scale_x_y": scale_x_y, + "iou_aware": iou_aware, + "iou_aware_factor": iou_aware_factor } helper.append_op( -- GitLab
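As a companion to the patch above, here is a minimal NumPy sketch of the IoU-aware confidence rescoring it introduces, mirroring the reference implementation in test_yolo_box_op.py; the helper name iou_aware_conf and the example shapes are illustrative only and are not part of the patch:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def iou_aware_conf(obj_logit, iou_logit, iou_aware_factor=0.5):
    # With iou_aware=True the operator rescales the objectness as
    # score_conf = obj^(1 - iou_aware_factor) * iou^(iou_aware_factor),
    # where obj and iou are the sigmoid-activated objectness and IoU channels.
    obj = sigmoid(obj_logit)
    iou = sigmoid(iou_logit)
    return obj ** (1.0 - iou_aware_factor) * iou ** iou_aware_factor

# Illustrative shapes: batch 2, 3 anchors, an 8x8 grid. With iou_aware=True the
# operator input has an_num * (6 + class_num) channels; the first an_num
# channels carry the IoU logits and the rest keep the usual YOLOv3 layout.
obj_logit = np.random.randn(2, 3, 8, 8).astype('float32')
iou_logit = np.random.randn(2, 3, 8, 8).astype('float32')
conf = iou_aware_conf(obj_logit, iou_logit, iou_aware_factor=0.5)
assert conf.shape == (2, 3, 8, 8)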