diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index 653bb519aca2ebd6b02ca293c0aa22d946ae5633..caae2597a5401f7d703429ac68afdfea3d1df362 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -32,7 +32,7 @@ __all__ = [ #'generate_proposals', 'iou_similarity', #'box_coder', - #'yolo_box', + 'yolo_box', #'multiclass_nms', 'distribute_fpn_proposals', 'collect_fpn_proposals', @@ -534,3 +534,138 @@ def distribute_fpn_proposals(fpn_rois, if rois_num is not None: return multi_rois, restore_ind, rois_num_per_level return multi_rois, restore_ind + + +def yolo_box( + x, + origin_shape, + anchors, + class_num, + conf_thresh, + downsample_ratio, + clip_bbox=True, + scale_x_y=1., + name=None, ): + """ + + This operator generates YOLO detection boxes from output of YOLOv3 network. + + The output of previous network is in shape [N, C, H, W], while H and W + should be the same, H and W specify the grid size, each grid point predict + given number boxes, this given number, which following will be represented as S, + is specified by the number of anchors. In the second dimension(the channel + dimension), C should be equal to S * (5 + class_num), class_num is the object + category number of source dataset(such as 80 in coco dataset), so the + second(channel) dimension, apart from 4 box location coordinates x, y, w, h, + also includes confidence score of the box and class one-hot key of each anchor + box. + Assume the 4 location coordinates are :math:`t_x, t_y, t_w, t_h`, the box + predictions should be as follows: + $$ + b_x = \\sigma(t_x) + c_x + $$ + $$ + b_y = \\sigma(t_y) + c_y + $$ + $$ + b_w = p_w e^{t_w} + $$ + $$ + b_h = p_h e^{t_h} + $$ + in the equation above, :math:`c_x, c_y` is the left top corner of current grid + and :math:`p_w, p_h` is specified by anchors. + The logistic regression value of the 5th channel of each anchor prediction boxes + represents the confidence score of each prediction box, and the logistic + regression value of the last :attr:`class_num` channels of each anchor prediction + boxes represents the classifcation scores. Boxes with confidence scores less than + :attr:`conf_thresh` should be ignored, and box final scores is the product of + confidence scores and classification scores. + $$ + score_{pred} = score_{conf} * score_{class} + $$ + + Args: + x (Tensor): The input tensor of YoloBox operator is a 4-D tensor with shape of [N, C, H, W]. + The second dimension(C) stores box locations, confidence score and + classification one-hot keys of each anchor box. Generally, X should be the output of YOLOv3 network. + The data type is float32 or float64. + origin_shape (Tensor): The image size tensor of YoloBox operator, This is a 2-D tensor with shape of [N, 2]. + This tensor holds height and width of each input image used for resizing output box in input image + scale. The data type is int32. + anchors (list|tuple): The anchor width and height, it will be parsed pair by pair. + class_num (int): The number of classes to predict. + conf_thresh (float): The confidence scores threshold of detection boxes. Boxes with confidence scores + under threshold should be ignored. + downsample_ratio (int): The downsample ratio from network input to YoloBox operator input, + so 32, 16, 8 should be set for the first, second, and thrid YoloBox operators. + clip_bbox (bool): Whether clip output bonding box in Input(ImgSize) boundary. Default true. + scale_x_y (float): Scale the center point of decoded bounding box. Default 1.0. + name (string): The default value is None. Normally there is no need + for user to set this property. For more information, + please refer to :ref:`api_guide_Name` + + Returns: + boxes Tensor: A 3-D tensor with shape [N, M, 4], the coordinates of boxes, N is the batch num, + M is output box number, and the 3rd dimension stores [xmin, ymin, xmax, ymax] coordinates of boxes. + scores Tensor: A 3-D tensor with shape [N, M, :attr:`class_num`], the coordinates of boxes, N is the batch num, + M is output box number. + + Raises: + TypeError: Attr anchors of yolo box must be list or tuple + TypeError: Attr class_num of yolo box must be an integer + TypeError: Attr conf_thresh of yolo box must be a float number + + Examples: + + .. code-block:: python + + import paddle + + paddle.enable_static() + x = paddle.static.data(name='x', shape=[None, 255, 13, 13], dtype='float32') + img_size = paddle.static.data(name='img_size',shape=[None, 2],dtype='int64') + anchors = [10, 13, 16, 30, 33, 23] + boxes,scores = ops.yolo_box(x=x, img_size=img_size, class_num=80, anchors=anchors, + conf_thresh=0.01, downsample_ratio=32) + """ + helper = LayerHelper('yolo_box', **locals()) + + if not isinstance(anchors, list) and not isinstance(anchors, tuple): + raise TypeError("Attr anchors of yolo_box must be list or tuple") + if not isinstance(class_num, int): + raise TypeError("Attr class_num of yolo_box must be an integer") + if not isinstance(conf_thresh, float): + raise TypeError("Attr ignore_thresh of yolo_box must be a float number") + + if in_dygraph_mode(): + attrs = ('anchors', anchors, 'class_num', class_num, 'conf_thresh', + conf_thresh, 'downsample_ratio', downsample_ratio, 'clip_bbox', + clip_bbox, 'scale_x_y', scale_x_y) + boxes, scores = core.ops.yolo_box(x, origin_shape, *attrs) + return boxes, scores + + boxes = helper.create_variable_for_type_inference(dtype=x.dtype) + scores = helper.create_variable_for_type_inference(dtype=x.dtype) + + attrs = { + "anchors": anchors, + "class_num": class_num, + "conf_thresh": conf_thresh, + "downsample_ratio": downsample_ratio, + "clip_bbox": clip_bbox, + "scale_x_y": scale_x_y, + } + + helper.append_op( + type='yolo_box', + inputs={ + "X": x, + "ImgSize": origin_shape, + }, + outputs={ + 'Boxes': boxes, + 'Scores': scores, + }, + attrs=attrs) + return boxes, scores diff --git a/ppdet/modeling/tests/test_ops.py b/ppdet/modeling/tests/test_ops.py index 7c290b88d9baaa11c62e681a978a6e4cdf6d1482..3fc3f0c885ff25f0cba4f74d107d284c8e68fb59 100644 --- a/ppdet/modeling/tests/test_ops.py +++ b/ppdet/modeling/tests/test_ops.py @@ -355,15 +355,19 @@ class TestIoUSimilarity(LayerTest): x_np = make_rois(h, w, [20], output_size) y_np = make_rois(h, w, [10], output_size) with self.static_graph(): - x = paddle.static.data(name='x', shape=[20, 4], dtype='float32') - y = paddle.static.data(name='y', shape=[10, 4], dtype='float32') - - iou = ops.iou_similarity(x=x, y=y) - iou_np, = self.get_static_graph_result( - feed={ - 'x': x_np, - 'y': y_np, - }, fetch_list=[iou], with_lod=False) + program = Program() + with program_guard(program): + x = paddle.static.data(name='x', shape=[20, 4], dtype='float32') + y = paddle.static.data(name='y', shape=[10, 4], dtype='float32') + + iou = ops.iou_similarity(x=x, y=y) + iou_np, = self.get_static_graph_result( + feed={ + 'x': x_np, + 'y': y_np, + }, + fetch_list=[iou], + with_lod=False) with self.dynamic_graph(): x_dy = base.to_variable(x_np) @@ -375,5 +379,85 @@ class TestIoUSimilarity(LayerTest): self.assertTrue(np.array_equal(iou_np, iou_dy_np)) +class TestYOLO_Box(LayerTest): + def test_yolo_box(self): + + # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2 + np_x = np.random.random([1, 30, 7, 7]).astype('float32') + np_origin_shape = np.array([[608, 608]], dtype='int32') + class_num = 10 + conf_thresh = 0.01 + downsample_ratio = 32 + scale_x_y = 1.2 + + # static + with self.static_graph(): + # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2 + x = paddle.static.data( + name='x', shape=[1, 30, 7, 7], dtype='float32') + origin_shape = paddle.static.data( + name='origin_shape', shape=[1, 2], dtype='int32') + + boxes, scores = ops.yolo_box( + x, + origin_shape, [10, 13, 30, 13], + class_num, + conf_thresh, + downsample_ratio, + scale_x_y=scale_x_y) + + boxes_np, scores_np = self.get_static_graph_result( + feed={ + 'x': np_x, + 'origin_shape': np_origin_shape, + 'anchors': [10, 13, 30, 13], + 'class_num': 10, + 'conf_thresh': 0.01, + 'downsample_ratio': 32, + 'scale_x_y': 1.0, + }, + fetch_list=[boxes, scores], + with_lod=False) + + # dygraph + with self.dynamic_graph(): + x_dy = fluid.layers.assign(np_x) + origin_shape_dy = fluid.layers.assign(np_origin_shape) + + boxes_dy, scores_dy = ops.yolo_box( + x_dy, + origin_shape_dy, [10, 13, 30, 13], + 10, + 0.01, + 32, + scale_x_y=scale_x_y) + + boxes_dy_np = boxes_dy.numpy() + scores_dy_np = scores_dy.numpy() + + self.assertTrue(np.array_equal(boxes_np, boxes_dy_np)) + self.assertTrue(np.array_equal(scores_np, scores_dy_np)) + + def test_yolo_box_error(self): + paddle.enable_static() + program = Program() + with program_guard(program): + # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2 + x = paddle.static.data( + name='x', shape=[1, 30, 7, 7], dtype='float32') + origin_shape = paddle.static.data( + name='origin_shape', shape=[1, 2], dtype='int32') + + self.assertRaises( + TypeError, + ops.yolo_box, + x, + origin_shape, [10, 13, 30, 13], + 10.123, + 0.01, + 32, + scale_x_y=1.2) + + if __name__ == '__main__': unittest.main()