diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc index 6f2a3ca87623847f261f0111bdfd8c168bb24b0a..e6f6c2a39358fdc94b36bd1aa2afd2e5d0a495c6 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -11,6 +11,7 @@ #include "paddle/fluid/operators/detection/yolo_box_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -31,19 +32,44 @@ class YoloBoxOp : public framework::OperatorWithKernel { auto anchors = ctx->Attrs().Get>("anchors"); int anchor_num = anchors.size() / 2; auto class_num = ctx->Attrs().Get("class_num"); + auto iou_aware = ctx->Attrs().Get("iou_aware"); + auto iou_aware_factor = ctx->Attrs().Get("iou_aware_factor"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, platform::errors::InvalidArgument( "Input(X) should be a 4-D tensor." "But received X dimension(%s)", dim_x.size())); - PADDLE_ENFORCE_EQ( - dim_x[1], anchor_num * (5 + class_num), - platform::errors::InvalidArgument( - "Input(X) dim[1] should be equal to (anchor_mask_number * (5 " - "+ class_num))." - "But received dim[1](%s) != (anchor_mask_number * " - "(5+class_num)(%s).", - dim_x[1], anchor_num * (5 + class_num))); + if (iou_aware) { + PADDLE_ENFORCE_EQ( + dim_x[1], anchor_num * (6 + class_num), + platform::errors::InvalidArgument( + "Input(X) dim[1] should be equal to (anchor_mask_number * (6 " + "+ class_num)) while iou_aware is true." + "But received dim[1](%s) != (anchor_mask_number * " + "(6+class_num)(%s).", + dim_x[1], anchor_num * (6 + class_num))); + PADDLE_ENFORCE_GE( + iou_aware_factor, 0, + platform::errors::InvalidArgument( + "Attr(iou_aware_factor) should greater than or equal to 0." + "But received iou_aware_factor (%s)", + iou_aware_factor)); + PADDLE_ENFORCE_LE( + iou_aware_factor, 1, + platform::errors::InvalidArgument( + "Attr(iou_aware_factor) should less than or equal to 1." + "But received iou_aware_factor (%s)", + iou_aware_factor)); + } else { + PADDLE_ENFORCE_EQ( + dim_x[1], anchor_num * (5 + class_num), + platform::errors::InvalidArgument( + "Input(X) dim[1] should be equal to (anchor_mask_number * (5 " + "+ class_num))." + "But received dim[1](%s) != (anchor_mask_number * " + "(5+class_num)(%s).", + dim_x[1], anchor_num * (5 + class_num))); + } PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2, platform::errors::InvalidArgument( "Input(ImgSize) should be a 2-D tensor." @@ -140,6 +166,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { "Scale the center point of decoded bounding " "box. Default 1.0") .SetDefault(1.); + AddAttr("iou_aware", "Whether use iou aware. Default false.") + .SetDefault(false); + AddAttr("iou_aware_factor", "iou aware factor. Default 0.5.") + .SetDefault(0.5); AddComment(R"DOC( This operator generates YOLO detection boxes from output of YOLOv3 network. @@ -147,7 +177,8 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { should be the same, H and W specify the grid size, each grid point predict given number boxes, this given number, which following will be represented as S, is specified by the number of anchors. In the second dimension(the channel - dimension), C should be equal to S * (5 + class_num), class_num is the object + dimension), C should be equal to S * (5 + class_num) if :attr:`iou_aware` is false, + otherwise C should be equal to S * (6 + class_num). class_num is the object category number of source dataset(such as 80 in coco dataset), so the second(channel) dimension, apart from 4 box location coordinates x, y, w, h, also includes confidence score of the box and class one-hot key of each anchor @@ -183,6 +214,15 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { score_{pred} = score_{conf} * score_{class} $$ + where the confidence scores follow the formula bellow + + .. math:: + + score_{conf} = \begin{case} + obj, \text{if } iou_aware == flase \\ + obj^{1 - iou_aware_factor} * iou^{iou_aware_factor}, \text{otherwise} + \end{case} + )DOC"); } }; @@ -197,3 +237,12 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel, ops::YoloBoxKernel); + +REGISTER_OP_VERSION(yolo_box) + .AddCheckpoint( + R"ROC( + Upgrade yolo box to add new attribute [iou_aware, iou_aware_factor]. + )ROC", + paddle::framework::compatible::OpVersionDesc() + .NewAttr("iou_aware", "Whether use iou aware", false) + .NewAttr("iou_aware_factor", "iou aware factor", 0.5f)); diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index 65dc73ef38323521590c9f5914ac13b321ef4469..ef0b870ebfdf7874ea1e80f8716bc496f3aca890 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -28,7 +28,8 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, const int w, const int an_num, const int class_num, const int box_num, int input_size_h, int input_size_w, bool clip_bbox, const float scale, - const float bias) { + const float bias, bool iou_aware, + const float iou_aware_factor) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; T box[4]; @@ -43,23 +44,29 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, int img_height = imgsize[2 * i]; int img_width = imgsize[2 * i + 1]; - int obj_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4); + int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4, + iou_aware); T conf = sigmoid(input[obj_idx]); + if (iou_aware) { + int iou_idx = GetIoUIndex(i, j, k * w + l, an_num, an_stride, grid_num); + T iou = sigmoid(input[iou_idx]); + conf = pow(conf, static_cast(1. - iou_aware_factor)) * + pow(iou, static_cast(iou_aware_factor)); + } if (conf < conf_thresh) { continue; } - int box_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); + int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0, + iou_aware); GetYoloBox(box, input, anchors, l, k, j, h, w, input_size_h, input_size_w, box_idx, grid_num, img_height, img_width, scale, bias); box_idx = (i * box_num + j * grid_num + k * w + l) * 4; CalcDetectionBox(boxes, box, box_idx, img_height, img_width, clip_bbox); - int label_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); + int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, + 5, iou_aware); int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num; CalcLabelScore(scores, input, label_idx, score_idx, class_num, conf, grid_num); @@ -80,6 +87,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { float conf_thresh = ctx.Attr("conf_thresh"); int downsample_ratio = ctx.Attr("downsample_ratio"); bool clip_bbox = ctx.Attr("clip_bbox"); + bool iou_aware = ctx.Attr("iou_aware"); + float iou_aware_factor = ctx.Attr("iou_aware_factor"); float scale = ctx.Attr("scale_x_y"); float bias = -0.5 * (scale - 1.); @@ -115,7 +124,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { ctx.cuda_device_context().stream()>>>( input_data, imgsize_data, boxes_data, scores_data, conf_thresh, anchors_data, n, h, w, an_num, class_num, box_num, input_size_h, - input_size_w, clip_bbox, scale, bias); + input_size_w, clip_bbox, scale, bias, iou_aware, iou_aware_factor); } }; diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h index 1cfef142bca7327cb039412719b7c002beb53cab..e06c81052a0f42c9db4d96e49d2708e64e4f3137 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -13,6 +13,7 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/hostdevice.h" namespace paddle { @@ -43,8 +44,19 @@ HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, int an_num, int an_stride, int stride, - int entry) { - return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; + int entry, bool iou_aware) { + if (iou_aware) { + return (batch * an_num + an_idx) * an_stride + + (batch * an_num + an_num + entry) * stride + hw_idx; + } else { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; + } +} + +HOSTDEVICE inline int GetIoUIndex(int batch, int an_idx, int hw_idx, int an_num, + int an_stride, int stride) { + return batch * an_num * an_stride + (batch * an_num + an_idx) * stride + + hw_idx; } template @@ -92,6 +104,8 @@ class YoloBoxKernel : public framework::OpKernel { float conf_thresh = ctx.Attr("conf_thresh"); int downsample_ratio = ctx.Attr("downsample_ratio"); bool clip_bbox = ctx.Attr("clip_bbox"); + bool iou_aware = ctx.Attr("iou_aware"); + float iou_aware_factor = ctx.Attr("iou_aware_factor"); float scale = ctx.Attr("scale_x_y"); float bias = -0.5 * (scale - 1.); @@ -127,15 +141,22 @@ class YoloBoxKernel : public framework::OpKernel { for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { - int obj_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 4); + int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, + stride, 4, iou_aware); T conf = sigmoid(input_data[obj_idx]); + if (iou_aware) { + int iou_idx = + GetIoUIndex(i, j, k * w + l, an_num, an_stride, stride); + T iou = sigmoid(input_data[iou_idx]); + conf = pow(conf, static_cast(1. - iou_aware_factor)) * + pow(iou, static_cast(iou_aware_factor)); + } if (conf < conf_thresh) { continue; } - int box_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); + int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, + stride, 0, iou_aware); GetYoloBox(box, input_data, anchors_data, l, k, j, h, w, input_size_h, input_size_w, box_idx, stride, img_height, img_width, scale, bias); @@ -143,8 +164,8 @@ class YoloBoxKernel : public framework::OpKernel { CalcDetectionBox(boxes_data, box, box_idx, img_height, img_width, clip_bbox); - int label_idx = - GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); + int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, + stride, 5, iou_aware); int score_idx = (i * box_num + j * stride + k * w + l) * class_num; CalcLabelScore(scores_data, input_data, label_idx, score_idx, class_num, conf, stride); diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index cf4abc207bd7541676ee7ad3c1ad5f9c67a67619..604bcc0e277769e074b3c531fa364a62b8078e49 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -1139,7 +1139,9 @@ def yolo_box(x, downsample_ratio, clip_bbox=True, name=None, - scale_x_y=1.): + scale_x_y=1., + iou_aware=False, + iou_aware_factor=0.5): """ ${comment} @@ -1156,6 +1158,8 @@ def yolo_box(x, name (string): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` + iou_aware (bool): ${iou_aware_comment} + iou_aware_factor (float): ${iou_aware_factor_comment} Returns: Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes, @@ -1204,6 +1208,8 @@ def yolo_box(x, "downsample_ratio": downsample_ratio, "clip_bbox": clip_bbox, "scale_x_y": scale_x_y, + "iou_aware": iou_aware, + "iou_aware_factor": iou_aware_factor } helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py index 24c463ebfc9a1336c6a7eea19d4190db17d6f08c..5793f0148fc5475a89c3b53831bc2019af542b61 100644 --- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -35,10 +35,16 @@ def YoloBox(x, img_size, attrs): downsample = attrs['downsample'] clip_bbox = attrs['clip_bbox'] scale_x_y = attrs['scale_x_y'] + iou_aware = attrs['iou_aware'] + iou_aware_factor = attrs['iou_aware_factor'] bias_x_y = -0.5 * (scale_x_y - 1.) input_h = downsample * h input_w = downsample * w + if iou_aware: + ioup = x[:, :an_num, :, :] + ioup = np.expand_dims(ioup, axis=-1) + x = x[:, an_num:, :, :] x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) pred_box = x[:, :, :, :, :4].copy() @@ -57,7 +63,11 @@ def YoloBox(x, img_size, attrs): pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h - pred_conf = sigmoid(x[:, :, :, :, 4:5]) + if iou_aware: + pred_conf = sigmoid(x[:, :, :, :, 4:5])**( + 1 - iou_aware_factor) * sigmoid(ioup)**iou_aware_factor + else: + pred_conf = sigmoid(x[:, :, :, :, 4:5]) pred_conf[pred_conf < conf_thresh] = 0. pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf pred_box = pred_box * (pred_conf > 0.).astype('float32') @@ -97,6 +107,8 @@ class TestYoloBoxOp(OpTest): "downsample": self.downsample, "clip_bbox": self.clip_bbox, "scale_x_y": self.scale_x_y, + "iou_aware": self.iou_aware, + "iou_aware_factor": self.iou_aware_factor } self.inputs = { @@ -123,6 +135,8 @@ class TestYoloBoxOp(OpTest): self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.imgsize_shape = (self.batch_size, 2) self.scale_x_y = 1. + self.iou_aware = False + self.iou_aware_factor = 0.5 class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): @@ -137,6 +151,8 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.imgsize_shape = (self.batch_size, 2) self.scale_x_y = 1. + self.iou_aware = False + self.iou_aware_factor = 0.5 class TestYoloBoxOpScaleXY(TestYoloBoxOp): @@ -151,19 +167,36 @@ class TestYoloBoxOpScaleXY(TestYoloBoxOp): self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.imgsize_shape = (self.batch_size, 2) self.scale_x_y = 1.2 + self.iou_aware = False + self.iou_aware_factor = 0.5 + + +class TestYoloBoxOpIoUAware(TestYoloBoxOp): + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int(len(self.anchors) // 2) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.clip_bbox = True + self.x_shape = (self.batch_size, an_num * (6 + self.class_num), 13, 13) + self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1. + self.iou_aware = True + self.iou_aware_factor = 0.5 class TestYoloBoxDygraph(unittest.TestCase): def test_dygraph(self): paddle.disable_static() - x = np.random.random([2, 14, 8, 8]).astype('float32') img_size = np.ones((2, 2)).astype('int32') - - x = paddle.to_tensor(x) img_size = paddle.to_tensor(img_size) + x1 = np.random.random([2, 14, 8, 8]).astype('float32') + x1 = paddle.to_tensor(x1) boxes, scores = paddle.vision.ops.yolo_box( - x, + x1, img_size=img_size, anchors=[10, 13, 16, 30], class_num=2, @@ -172,16 +205,30 @@ class TestYoloBoxDygraph(unittest.TestCase): clip_bbox=True, scale_x_y=1.) assert boxes is not None and scores is not None + + x2 = np.random.random([2, 16, 8, 8]).astype('float32') + x2 = paddle.to_tensor(x2) + boxes, scores = paddle.vision.ops.yolo_box( + x2, + img_size=img_size, + anchors=[10, 13, 16, 30], + class_num=2, + conf_thresh=0.01, + downsample_ratio=8, + clip_bbox=True, + scale_x_y=1., + iou_aware=True, + iou_aware_factor=0.5) paddle.enable_static() class TestYoloBoxStatic(unittest.TestCase): def test_static(self): - x = paddle.static.data('x', [2, 14, 8, 8], 'float32') + x1 = paddle.static.data('x1', [2, 14, 8, 8], 'float32') img_size = paddle.static.data('img_size', [2, 2], 'int32') boxes, scores = paddle.vision.ops.yolo_box( - x, + x1, img_size=img_size, anchors=[10, 13, 16, 30], class_num=2, @@ -191,6 +238,20 @@ class TestYoloBoxStatic(unittest.TestCase): scale_x_y=1.) assert boxes is not None and scores is not None + x2 = paddle.static.data('x2', [2, 16, 8, 8], 'float32') + boxes, scores = paddle.vision.ops.yolo_box( + x2, + img_size=img_size, + anchors=[10, 13, 16, 30], + class_num=2, + conf_thresh=0.01, + downsample_ratio=8, + clip_bbox=True, + scale_x_y=1., + iou_aware=True, + iou_aware_factor=0.5) + assert boxes is not None and scores is not None + class TestYoloBoxOpHW(TestYoloBoxOp): def initTestCase(self): @@ -204,6 +265,8 @@ class TestYoloBoxOpHW(TestYoloBoxOp): self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 9) self.imgsize_shape = (self.batch_size, 2) self.scale_x_y = 1. + self.iou_aware = False + self.iou_aware_factor = 0.5 if __name__ == "__main__": diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 60a7a90c9be89591e681192f5e886f9c5443a8c0..769e33c73557916d43e71a3ea48064bb6993a949 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -247,7 +247,9 @@ def yolo_box(x, downsample_ratio, clip_bbox=True, name=None, - scale_x_y=1.): + scale_x_y=1., + iou_aware=False, + iou_aware_factor=0.5): r""" This operator generates YOLO detection boxes from output of YOLOv3 network. @@ -256,7 +258,8 @@ def yolo_box(x, should be the same, H and W specify the grid size, each grid point predict given number boxes, this given number, which following will be represented as S, is specified by the number of anchors. In the second dimension(the channel - dimension), C should be equal to S * (5 + class_num), class_num is the object + dimension), C should be equal to S * (5 + class_num) if :attr:`iou_aware` is false, + otherwise C should be equal to S * (6 + class_num). class_num is the object category number of source dataset(such as 80 in coco dataset), so the second(channel) dimension, apart from 4 box location coordinates x, y, w, h, also includes confidence score of the box and class one-hot key of each anchor @@ -292,6 +295,15 @@ def yolo_box(x, score_{pred} = score_{conf} * score_{class} $$ + where the confidence scores follow the formula bellow + + .. math:: + + score_{conf} = \begin{case} + obj, \text{if } iou_aware == flase \\ + obj^{1 - iou_aware_factor} * iou^{iou_aware_factor}, \text{otherwise} + \end{case} + Args: x (Tensor): The input tensor of YoloBox operator is a 4-D tensor with shape of [N, C, H, W]. The second dimension(C) stores box @@ -313,13 +325,14 @@ def yolo_box(x, should be set for the first, second, and thrid :attr:`yolo_box` layer. clip_bbox (bool): Whether clip output bonding box in :attr:`img_size` - boundary. Default true." - " + boundary. Default true. scale_x_y (float): Scale the center point of decoded bounding box. Default 1.0 name (string): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` + iou_aware (bool): Whether use iou aware. Default false + iou_aware_factor (float): iou aware factor. Default 0.5 Returns: Tensor: A 3-D tensor with shape [N, M, 4], the coordinates of boxes, @@ -358,7 +371,8 @@ def yolo_box(x, boxes, scores = core.ops.yolo_box( x, img_size, 'anchors', anchors, 'class_num', class_num, 'conf_thresh', conf_thresh, 'downsample_ratio', downsample_ratio, - 'clip_bbox', clip_bbox, 'scale_x_y', scale_x_y) + 'clip_bbox', clip_bbox, 'scale_x_y', scale_x_y, 'iou_aware', + iou_aware, 'iou_aware_factor', iou_aware_factor) return boxes, scores helper = LayerHelper('yolo_box', **locals()) @@ -378,6 +392,8 @@ def yolo_box(x, "downsample_ratio": downsample_ratio, "clip_bbox": clip_bbox, "scale_x_y": scale_x_y, + "iou_aware": iou_aware, + "iou_aware_factor": iou_aware_factor } helper.append_op(