From 3fc2622d32d7eb0deb13c24b8902858f183376e9 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Thu, 29 Oct 2020 14:15:07 +0800 Subject: [PATCH] Add multiclass_nms,anchor_generator,prior_box in dygraph (#1627) --- ppdet/modeling/layers.py | 4 +- ppdet/modeling/ops.py | 440 +++++++++++++++++++++++++++++-- ppdet/modeling/tests/test_ops.py | 190 ++++++++++++- 3 files changed, 602 insertions(+), 32 deletions(-) diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index e65505908..f61992b11 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -45,7 +45,7 @@ class AnchorGeneratorRPN(object): stride = self.stride if ( level is None or self.anchor_start_size is None) else ( self.stride[0] * (2.**level), self.stride[1] * (2.**level)) - anchor, var = fluid.layers.anchor_generator( + anchor, var = ops.anchor_generator( input=input, anchor_sizes=anchor_sizes, aspect_ratios=self.aspect_ratios, @@ -367,7 +367,7 @@ class DecodeClipNms(object): @register @serializable class MultiClassNMS(object): - __op__ = fluid.layers.multiclass_nms + __op__ = ops.multiclass_nms __append_doc__ = True def __init__(self, diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index a9ae84643..a55f8efb2 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -27,13 +27,13 @@ from functools import reduce __all__ = [ 'roi_pool', 'roi_align', - #'prior_box', - #'anchor_generator', + 'prior_box', + 'anchor_generator', #'generate_proposals', 'iou_similarity', #'box_coder', 'yolo_box', - #'multiclass_nms', + 'multiclass_nms', 'distribute_fpn_proposals', 'collect_fpn_proposals', 'matrix_nms', @@ -84,6 +84,7 @@ def roi_pool(input, .. code-block:: python import paddle + from ppdet.modeling import ops paddle.enable_static() x = paddle.static.data( @@ -187,6 +188,7 @@ def roi_align(input, .. code-block:: python import paddle + from ppdet.modeling import ops paddle.enable_static() x = paddle.static.data( @@ -278,12 +280,12 @@ def iou_similarity(x, y, box_normalized=True, name=None): Examples: .. code-block:: python - import numpy as np import paddle + from ppdet.modeling import ops paddle.enable_static() - x = paddle.data(name='x', shape=[None, 4], dtype='float32') - y = paddle.data(name='y', shape=[None, 4], dtype='float32') + x = paddle.static.data(name='x', shape=[None, 4], dtype='float32') + y = paddle.static.data(name='y', shape=[None, 4], dtype='float32') iou = ops.iou_similarity(x=x, y=y) """ @@ -355,8 +357,8 @@ def collect_fpn_proposals(multi_rois, Examples: .. code-block:: python - import paddle.fluid as fluid import paddle + from ppdet.modeling import ops paddle.enable_static() multi_rois = [] multi_scores = [] @@ -367,7 +369,7 @@ def collect_fpn_proposals(multi_rois, multi_scores.append(paddle.static.data( name='score_'+str(i), shape=[None, 1], dtype='float32', lod_level=1)) - fpn_rois = fluid.layers.collect_fpn_proposals( + fpn_rois = ops.collect_fpn_proposals( multi_rois=multi_rois, multi_scores=multi_scores, min_level=2, @@ -475,12 +477,12 @@ def distribute_fpn_proposals(fpn_rois, Examples: .. code-block:: python - import paddle.fluid as fluid import paddle + from ppdet.modeling import ops paddle.enable_static() fpn_rois = paddle.static.data( name='data', shape=[None, 4], dtype='float32', lod_level=1) - multi_rois, restore_ind = fluid.layers.distribute_fpn_proposals( + multi_rois, restore_ind = ops.distribute_fpn_proposals( fpn_rois=fpn_rois, min_level=2, max_level=5, @@ -621,6 +623,7 @@ def yolo_box( .. 
code-block:: python import paddle + from ppdet.modeling import ops paddle.enable_static() x = paddle.static.data(name='x', shape=[None, 255, 13, 13], dtype='float32') @@ -671,6 +674,411 @@ def yolo_box( return boxes, scores +def prior_box(input, + image, + min_sizes, + max_sizes=None, + aspect_ratios=[1.], + variance=[0.1, 0.1, 0.2, 0.2], + flip=False, + clip=False, + steps=[0.0, 0.0], + offset=0.5, + min_max_aspect_ratios_order=False, + name=None): + """ + + This op generates prior boxes for SSD(Single Shot MultiBox Detector) algorithm. + Each position of the input produce N prior boxes, N is determined by + the count of min_sizes, max_sizes and aspect_ratios, The size of the + box is in range(min_size, max_size) interval, which is generated in + sequence according to the aspect_ratios. + + Parameters: + input(Tensor): 4-D tensor(NCHW), the data type should be float32 or float64. + image(Tensor): 4-D tensor(NCHW), the input image data of PriorBoxOp, + the data type should be float32 or float64. + min_sizes(list|tuple|float): the min sizes of generated prior boxes. + max_sizes(list|tuple|None): the max sizes of generated prior boxes. + Default: None. + aspect_ratios(list|tuple|float): the aspect ratios of generated + prior boxes. Default: [1.]. + variance(list|tuple): the variances to be encoded in prior boxes. + Default:[0.1, 0.1, 0.2, 0.2]. + flip(bool): Whether to flip aspect ratios. Default:False. + clip(bool): Whether to clip out-of-boundary boxes. Default: False. + step(list|tuple): Prior boxes step across width and height, If + step[0] equals to 0.0 or step[1] equals to 0.0, the prior boxes step across + height or weight of the input will be automatically calculated. + Default: [0., 0.] + offset(float): Prior boxes center offset. Default: 0.5 + min_max_aspect_ratios_order(bool): If set True, the output prior box is + in order of [min, max, aspect_ratios], which is consistent with + Caffe. Please note, this order affects the weights order of + convolution layer followed by and does not affect the final + detection results. Default: False. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Tuple: A tuple with two Variable (boxes, variances) + + boxes(Tensor): the output prior boxes of PriorBox. + 4-D tensor, the layout is [H, W, num_priors, 4]. + H is the height of input, W is the width of input, + num_priors is the total box count of each position of input. + + variances(Tensor): the expanded variances of PriorBox. + 4-D tensor, the layput is [H, W, num_priors, 4]. + H is the height of input, W is the width of input + num_priors is the total box count of each position of input + + Examples: + .. 
code-block:: python + + import paddle + from ppdet.modeling import ops + + paddle.enable_static() + input = paddle.static.data(name="input", shape=[None,3,6,9]) + image = paddle.static.data(name="image", shape=[None,3,9,12]) + box, var = ops.prior_box( + input=input, + image=image, + min_sizes=[100.], + clip=True, + flip=True) + """ + helper = LayerHelper("prior_box", **locals()) + dtype = helper.input_dtype() + check_variable_and_dtype( + input, 'input', ['uint8', 'int8', 'float32', 'float64'], 'prior_box') + + def _is_list_or_tuple_(data): + return (isinstance(data, list) or isinstance(data, tuple)) + + if not _is_list_or_tuple_(min_sizes): + min_sizes = [min_sizes] + if not _is_list_or_tuple_(aspect_ratios): + aspect_ratios = [aspect_ratios] + if not (_is_list_or_tuple_(steps) and len(steps) == 2): + raise ValueError('steps should be a list or tuple ', + 'with length 2, (step_width, step_height).') + + min_sizes = list(map(float, min_sizes)) + aspect_ratios = list(map(float, aspect_ratios)) + steps = list(map(float, steps)) + + cur_max_sizes = None + if max_sizes is not None and len(max_sizes) > 0 and max_sizes[0] > 0: + if not _is_list_or_tuple_(max_sizes): + max_sizes = [max_sizes] + cur_max_sizes = max_sizes + + if in_dygraph_mode(): + attrs = [ + 'min_sizes', min_sizes, 'aspect_ratios', aspect_ratios, 'variances', + variance, 'flip', flip, 'clip', clip, 'step_w', steps[0], 'step_h', + steps[1], 'offset', offset, 'min_max_aspect_ratios_order', + min_max_aspect_ratios_order + ] + if cur_max_sizes is not None: + attrs.extend('max_sizes', max_sizes) + attrs = tuple(attrs) + box, var = core.ops.prior_box(input, image, *attrs) + return box, var + + attrs = { + 'min_sizes': min_sizes, + 'aspect_ratios': aspect_ratios, + 'variances': variance, + 'flip': flip, + 'clip': clip, + 'step_w': steps[0], + 'step_h': steps[1], + 'offset': offset, + 'min_max_aspect_ratios_order': min_max_aspect_ratios_order + } + + if cur_max_sizes is not None: + attrs['max_sizes'] = cur_max_sizes + + box = helper.create_variable_for_type_inference(dtype) + var = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="prior_box", + inputs={"Input": input, + "Image": image}, + outputs={"Boxes": box, + "Variances": var}, + attrs=attrs, ) + box.stop_gradient = True + var.stop_gradient = True + return box, var + + +def anchor_generator(input, + anchor_sizes=None, + aspect_ratios=None, + variance=[0.1, 0.1, 0.2, 0.2], + stride=None, + offset=0.5, + name=None): + """ + + This op generate anchors for Faster RCNN algorithm. + Each position of the input produce N anchors, N = + size(anchor_sizes) * size(aspect_ratios). The order of generated anchors + is firstly aspect_ratios loop then anchor_sizes loop. + + Args: + input(Tensor): 4-D Tensor with shape [N,C,H,W]. The input feature map. + anchor_sizes(float32|list|tuple, optional): The anchor sizes of generated + anchors, given in absolute pixels e.g. [64., 128., 256., 512.]. + For instance, the anchor size of 64 means the area of this anchor + equals to 64**2. None by default. + aspect_ratios(float32|list|tuple, optional): The height / width ratios + of generated anchors, e.g. [0.5, 1.0, 2.0]. None by default. + variance(list|tuple, optional): The variances to be used in box + regression deltas. The data type is float32, [0.1, 0.1, 0.2, 0.2] by + default. + stride(list|tuple, optional): The anchors stride across width and height. + The data type is float32. e.g. [16.0, 16.0]. None by default. + offset(float32, optional): Prior boxes center offset. 
0.5 by default. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and None + by default. + + Returns: + Tuple: + + Anchors(Tensor): The output anchors with a layout of [H, W, num_anchors, 4]. + H is the height of input, W is the width of input, + num_anchors is the box count of each position. + Each anchor is in (xmin, ymin, xmax, ymax) format an unnormalized. + + Variances(Tensor): The expanded variances of anchors + with a layout of [H, W, num_priors, 4]. + H is the height of input, W is the width of input + num_anchors is the box count of each position. + Each variance is in (xcenter, ycenter, w, h) format. + + + Examples: + + .. code-block:: python + + import paddle + from ppdet.modeling import ops + + paddle.enable_static() + conv1 = paddle.static.data(name='input', shape=[None, 48, 16, 16], dtype='float32') + anchor, var = ops.anchor_generator( + input=conv1, + anchor_sizes=[64, 128, 256, 512], + aspect_ratios=[0.5, 1.0, 2.0], + variance=[0.1, 0.1, 0.2, 0.2], + stride=[16.0, 16.0], + offset=0.5) + """ + helper = LayerHelper("anchor_generator", **locals()) + dtype = helper.input_dtype() + + def _is_list_or_tuple_(data): + return (isinstance(data, list) or isinstance(data, tuple)) + + if not _is_list_or_tuple_(anchor_sizes): + anchor_sizes = [anchor_sizes] + if not _is_list_or_tuple_(aspect_ratios): + aspect_ratios = [aspect_ratios] + if not (_is_list_or_tuple_(stride) and len(stride) == 2): + raise ValueError('stride should be a list or tuple ', + 'with length 2, (stride_width, stride_height).') + + anchor_sizes = list(map(float, anchor_sizes)) + aspect_ratios = list(map(float, aspect_ratios)) + stride = list(map(float, stride)) + + if in_dygraph_mode(): + attrs = ('anchor_sizes', anchor_sizes, 'aspect_ratios', aspect_ratios, + 'variances', variance, 'stride', stride, 'offset', offset) + anchor, var = core.ops.anchor_generator(input, *attrs) + return anchor, var + + attrs = { + 'anchor_sizes': anchor_sizes, + 'aspect_ratios': aspect_ratios, + 'variances': variance, + 'stride': stride, + 'offset': offset + } + + anchor = helper.create_variable_for_type_inference(dtype) + var = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="anchor_generator", + inputs={"Input": input}, + outputs={"Anchors": anchor, + "Variances": var}, + attrs=attrs, ) + anchor.stop_gradient = True + var.stop_gradient = True + return anchor, var + + +def multiclass_nms(bboxes, + scores, + score_threshold, + nms_top_k, + keep_top_k, + nms_threshold=0.3, + normalized=True, + nms_eta=1., + background_label=0, + return_index=False, + rois_num=None, + name=None): + """ + This operator is to do multi-class non maximum suppression (NMS) on + boxes and scores. + In the NMS step, this operator greedily selects a subset of detection bounding + boxes that have high scores larger than score_threshold, if providing this + threshold, then selects the largest nms_top_k confidences scores if nms_top_k + is larger than -1. Then this operator pruns away boxes that have high IOU + (intersection over union) overlap with already selected boxes by adaptive + threshold NMS based on parameters of nms_threshold and nms_eta. + Aftern NMS step, at most keep_top_k number of total bboxes are to be kept + per image if keep_top_k is larger than -1. + Args: + bboxes (Tensor): Two types of bboxes are supported: + 1. 
(Tensor) A 3-D Tensor with shape + [N, M, 4 or 8 16 24 32] represents the + predicted locations of M bounding bboxes, + N is the batch size. Each bounding box has four + coordinate values and the layout is + [xmin, ymin, xmax, ymax], when box size equals to 4. + 2. (LoDTensor) A 3-D Tensor with shape [M, C, 4] + M is the number of bounding boxes, C is the + class number + scores (Tensor): Two types of scores are supported: + 1. (Tensor) A 3-D Tensor with shape [N, C, M] + represents the predicted confidence predictions. + N is the batch size, C is the class number, M is + number of bounding boxes. For each category there + are total M scores which corresponding M bounding + boxes. Please note, M is equal to the 2nd dimension + of BBoxes. + 2. (LoDTensor) A 2-D LoDTensor with shape [M, C]. + M is the number of bbox, C is the class number. + In this case, input BBoxes should be the second + case with shape [M, C, 4]. + background_label (int): The index of background label, the background + label will be ignored. If set to -1, then all + categories will be considered. Default: 0 + score_threshold (float): Threshold to filter out bounding boxes with + low confidence score. If not provided, + consider all boxes. + nms_top_k (int): Maximum number of detections to be kept according to + the confidences after the filtering detections based + on score_threshold. + nms_threshold (float): The threshold to be used in NMS. Default: 0.3 + nms_eta (float): The threshold to be used in NMS. Default: 1.0 + keep_top_k (int): Number of total bboxes to be kept per image after NMS + step. -1 means keeping all bboxes after NMS step. + normalized (bool): Whether detections are normalized. Default: True + return_index(bool): Whether return selected index. Default: False + rois_num(Tensor): 1-D Tensor contains the number of RoIs in each image. + The shape is [B] and data type is int32. B is the number of images. + If it is not None then return a list of 1-D Tensor. Each element + is the output RoIs' number of each image on the corresponding level + and the shape is [B]. None by default. + name(str): Name of the multiclass nms op. Default: None. + Returns: + A tuple with two Variables: (Out, Index) if return_index is True, + otherwise, a tuple with one Variable(Out) is returned. + Out: A 2-D LoDTensor with shape [No, 6] represents the detections. + Each row has 6 values: [label, confidence, xmin, ymin, xmax, ymax] + or A 2-D LoDTensor with shape [No, 10] represents the detections. + Each row has 10 values: [label, confidence, x1, y1, x2, y2, x3, y3, + x4, y4]. No is the total number of detections. + If all images have not detected results, all elements in LoD will be + 0, and output tensor is empty (None). + Index: Only return when return_index is True. A 2-D LoDTensor with + shape [No, 1] represents the selected index which type is Integer. + The index is the absolute value cross batches. No is the same number + as Out. If the index is used to gather other attribute such as age, + one needs to reshape the input(N, M, 1) to (N * M, 1) as first, where + N is the batch size and M is the number of boxes. + Examples: + .. 
code-block:: python + + import paddle + from ppdet.modeling import ops + boxes = paddle.static.data(name='bboxes', shape=[81, 4], + dtype='float32', lod_level=1) + scores = paddle.static.data(name='scores', shape=[81], + dtype='float32', lod_level=1) + out, index = ops.multiclass_nms(bboxes=boxes, + scores=scores, + background_label=0, + score_threshold=0.5, + nms_top_k=400, + nms_threshold=0.3, + keep_top_k=200, + normalized=False, + return_index=True) + """ + helper = LayerHelper('multiclass_nms3', **locals()) + + if in_dygraph_mode(): + assert rois_num is not None, "rois_num should not be None in dygraph mode." + attrs = ('background_label', background_label, 'score_threshold', + score_threshold, 'nms_top_k', nms_top_k, 'nms_threshold', + nms_threshold, 'keep_top_k', keep_top_k, 'nms_eta', nms_eta, + 'normalized', normalized) + output, index, nms_rois_num = core.ops.multiclass_nms3(bboxes, scores, + rois_num, *attrs) + if return_index: + return output, index, nms_rois_num + else: + return output, nms_rois_num + + output = helper.create_variable_for_type_inference(dtype=bboxes.dtype) + index = helper.create_variable_for_type_inference(dtype='int') + + inputs = {'BBoxes': bboxes, 'Scores': scores} + outputs = {'Out': output, 'Index': index} + + if rois_num is not None: + inputs['RoisNum'] = rois_num + nms_rois_num = helper.create_variable_for_type_inference(dtype='int32') + outputs['NmsRoisNum'] = nms_rois_num + + helper.append_op( + type="multiclass_nms3", + inputs=inputs, + attrs={ + 'background_label': background_label, + 'score_threshold': score_threshold, + 'nms_top_k': nms_top_k, + 'nms_threshold': nms_threshold, + 'keep_top_k': keep_top_k, + 'nms_eta': nms_eta, + 'normalized': normalized + }, + outputs=outputs) + output.stop_gradient = True + index.stop_gradient = True + + if return_index and rois_num is not None: + return output, index, nms_rois_num + elif return_index and rois_num is None: + return output, index + elif not return_index and rois_num is not None: + return output, nms_rois_num + return output + + def matrix_nms(bboxes, scores, score_threshold, @@ -686,16 +1094,13 @@ def matrix_nms(bboxes, name=None): """ **Matrix NMS** - This operator does matrix non maximum suppression (NMS). - First selects a subset of candidate bounding boxes that have higher scores than score_threshold (if provided), then the top k candidate is selected if nms_top_k is larger than -1. Score of the remaining candidate are then decayed according to the Matrix NMS scheme. Aftern NMS step, at most keep_top_k number of total bboxes are to be kept per image if keep_top_k is larger than -1. - Args: bboxes (Tensor): A 3-D Tensor with shape [N, M, 4] represents the predicted locations of M bounding bboxes, @@ -728,30 +1133,22 @@ def matrix_nms(bboxes, return_index(bool): Whether return selected index. Default: False return_rois_num(bool): whether return rois_num. Default: True name(str): Name of the matrix nms op. Default: None. - Returns: A tuple with three Tensor: (Out, Index, RoisNum) if return_index is True, otherwise, a tuple with two Tensor (Out, RoisNum) is returned. - Out (Tensor): A 2-D Tensor with shape [No, 6] containing the detection results. Each row has 6 values: [label, confidence, xmin, ymin, xmax, ymax] (After version 1.3, when no boxes detected, the lod is changed from {0} to {1}) - Index (Tensor): A 2-D Tensor with shape [No, 1] containing the selected indices, which are absolute values cross batches. 
- rois_num (Tensor): A 1-D Tensor with shape [N] containing the number of detected boxes in each image. - Examples: .. code-block:: python - - import paddle from ppdet.modeling import ops - boxes = paddle.static.data(name='bboxes', shape=[None,81, 4], dtype='float32', lod_level=1) scores = paddle.static.data(name='scores', shape=[None,81], @@ -759,7 +1156,6 @@ def matrix_nms(bboxes, out = ops.matrix_nms(bboxes=boxes, scores=scores, background_label=0, score_threshold=0.5, post_threshold=0.1, nms_top_k=400, keep_top_k=200, normalized=False) - """ check_variable_and_dtype(bboxes, 'BBoxes', ['float32', 'float64'], 'matrix_nms') diff --git a/ppdet/modeling/tests/test_ops.py b/ppdet/modeling/tests/test_ops.py index 87cbf8dc6..d2097df27 100644 --- a/ppdet/modeling/tests/test_ops.py +++ b/ppdet/modeling/tests/test_ops.py @@ -383,7 +383,7 @@ class TestIoUSimilarity(LayerTest): self.assertTrue(np.array_equal(iou_np, iou_dy_np)) -class TestYOLO_Box(LayerTest): +class TestYOLOBox(LayerTest): def test_yolo_box(self): # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2 @@ -414,11 +414,6 @@ class TestYOLO_Box(LayerTest): feed={ 'x': np_x, 'origin_shape': np_origin_shape, - 'anchors': [10, 13, 30, 13], - 'class_num': 10, - 'conf_thresh': 0.01, - 'downsample_ratio': 32, - 'scale_x_y': 1.0, }, fetch_list=[boxes, scores], with_lod=False) @@ -439,8 +434,8 @@ class TestYOLO_Box(LayerTest): boxes_dy_np = boxes_dy.numpy() scores_dy_np = scores_dy.numpy() - self.assertTrue(np.array_equal(boxes_np, boxes_dy_np)) - self.assertTrue(np.array_equal(scores_np, scores_dy_np)) + self.assertTrue(np.array_equal(boxes_np, boxes_dy_np)) + self.assertTrue(np.array_equal(scores_np, scores_dy_np)) def test_yolo_box_error(self): paddle.enable_static() @@ -463,6 +458,185 @@ class TestYOLO_Box(LayerTest): scale_x_y=1.2) +class TestPriorBox(LayerTest): + def test_prior_box(self): + input_np = np.random.rand(2, 10, 32, 32).astype('float32') + image_np = np.random.rand(2, 10, 40, 40).astype('float32') + min_sizes = [2, 4] + with self.static_graph(): + input = paddle.static.data( + name='input', shape=[2, 10, 32, 32], dtype='float32') + image = paddle.static.data( + name='image', shape=[2, 10, 40, 40], dtype='float32') + + box, var = ops.prior_box( + input=input, + image=image, + min_sizes=min_sizes, + clip=True, + flip=True) + box_np, var_np = self.get_static_graph_result( + feed={ + 'input': input_np, + 'image': image_np, + }, + fetch_list=[box, var], + with_lod=False) + + with self.dynamic_graph(): + inputs_dy = base.to_variable(input_np) + image_dy = base.to_variable(image_np) + + box_dy, var_dy = ops.prior_box( + input=inputs_dy, + image=image_dy, + min_sizes=min_sizes, + clip=True, + flip=True) + box_dy_np = box_dy.numpy() + var_dy_np = var_dy.numpy() + + self.assertTrue(np.array_equal(box_np, box_dy_np)) + self.assertTrue(np.array_equal(var_np, var_dy_np)) + + def test_prior_box_error(self): + program = Program() + paddle.enable_static() + with program_guard(program): + input = paddle.static.data( + name='input', shape=[2, 10, 32, 32], dtype='int32') + image = paddle.static.data( + name='image', shape=[2, 10, 40, 40], dtype='int32') + self.assertRaises( + TypeError, + ops.prior_box, + input=input, + image=image, + min_sizes=[2, 4], + clip=True, + flip=True) + + +class TestAnchorGenerator(LayerTest): + def test_anchor_generator(self): + b, c, h, w = 2, 48, 16, 16 + input_np = np.random.rand(2, 48, 16, 16).astype('float32') + paddle.enable_static() + with self.static_graph(): + input = paddle.static.data( + 
name='input', shape=[b, c, h, w], dtype='float32') + + anchor, var = ops.anchor_generator( + input=input, + anchor_sizes=[64, 128, 256, 512], + aspect_ratios=[0.5, 1.0, 2.0], + variance=[0.1, 0.1, 0.2, 0.2], + stride=[16.0, 16.0], + offset=0.5) + anchor_np, var_np = self.get_static_graph_result( + feed={'input': input_np, }, + fetch_list=[anchor, var], + with_lod=False) + + with self.dynamic_graph(): + inputs_dy = base.to_variable(input_np) + + anchor_dy, var_dy = ops.anchor_generator( + input=inputs_dy, + anchor_sizes=[64, 128, 256, 512], + aspect_ratios=[0.5, 1.0, 2.0], + variance=[0.1, 0.1, 0.2, 0.2], + stride=[16.0, 16.0], + offset=0.5) + anchor_dy_np = anchor_dy.numpy() + var_dy_np = var_dy.numpy() + + self.assertTrue(np.array_equal(anchor_np, anchor_dy_np)) + self.assertTrue(np.array_equal(var_np, var_dy_np)) + + +class TestMulticlassNms(LayerTest): + def test_multiclass_nms(self): + boxes_np = np.random.rand(81, 4).astype('float32') + scores_np = np.random.rand(81).astype('float32') + rois_num_np = np.array([40, 41]).astype('int32') + with self.static_graph(): + boxes = paddle.static.data( + name='bboxes', shape=[81, 4], dtype='float32', lod_level=1) + scores = paddle.static.data( + name='scores', shape=[81], dtype='float32', lod_level=1) + rois_num = paddle.static.data( + name='rois_num', shape=[40, 41], dtype='int32') + + output = ops.multiclass_nms( + bboxes=boxes, + scores=scores, + background_label=0, + score_threshold=0.5, + nms_top_k=400, + nms_threshold=0.3, + keep_top_k=200, + normalized=False, + return_index=True, + rois_num=rois_num) + out_np, index_np, nms_rois_num_np = self.get_static_graph_result( + feed={ + 'bboxes': boxes_np, + 'scores': scores_np, + 'rois_num': rois_num_np + }, + fetch_list=output, + with_lod=False) + + with self.dynamic_graph(): + boxes_dy = base.to_variable(boxes_np) + scores_dy = base.to_variable(scores_np) + rois_num_dy = base.to_variable(rois_num_np) + + out_dy, index_dy, nms_rois_num_dy = ops.multiclass_nms( + bboxes=boxes_dy, + scores=scores_dy, + background_label=0, + score_threshold=0.5, + nms_top_k=400, + nms_threshold=0.3, + keep_top_k=200, + normalized=False, + return_index=True, + rois_num=rois_num_dy) + out_dy_np = out_dy.numpy() + index_dy_np = index_dy.numpy() + nms_rois_num_dy_np = nms_rois_num_dy.numpy() + + self.assertTrue(np.array_equal(out_np, out_dy_np)) + self.assertTrue(np.array_equal(index_np, index_dy_np)) + self.assertTrue(np.array_equal(nms_rois_num_np, nms_rois_num_dy_np)) + + def test_multiclass_nms_error(self): + program = Program() + paddle.enable_static() + with program_guard(program): + boxes = paddle.static.data( + name='bboxes', shape=[81, 4], dtype='float32', lod_level=1) + scores = paddle.static.data( + name='scores', shape=[81], dtype='float32', lod_level=1) + rois_num = paddle.static.data( + name='rois_num', shape=[40, 41], dtype='int32') + self.assertRaises( + TypeError, + ops.multiclass_nms, + boxes=boxes, + scores=scores, + background_label=0, + score_threshold=0.5, + nms_top_k=400, + nms_threshold=0.3, + keep_top_k=200, + normalized=False, + return_index=True, + rois_num=rois_num) + + class TestMatrixNMS(LayerTest): def test_matrix_nms(self): N, M, C = 7, 1200, 21 -- GitLab
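
Usage sketch (illustrative, not part of the patch): the snippet below exercises the three ops added here in dygraph mode, mirroring the shapes used in the new unit tests. It assumes the Paddle 2.0rc eager API (paddle.disable_static and paddle.to_tensor are available); note that multiclass_nms additionally requires rois_num in dygraph mode, as asserted in the implementation.

.. code-block:: python

    import numpy as np
    import paddle
    from ppdet.modeling import ops

    # Run eagerly; with Paddle 2.0rc dygraph mode is the default, so this
    # call only matters if static mode was enabled earlier.
    paddle.disable_static()

    # prior_box: SSD-style prior boxes from a feature map and the input image.
    feat = paddle.to_tensor(np.random.rand(2, 10, 32, 32).astype('float32'))
    image = paddle.to_tensor(np.random.rand(2, 10, 40, 40).astype('float32'))
    boxes, variances = ops.prior_box(
        input=feat, image=image, min_sizes=[2., 4.], clip=True, flip=True)

    # anchor_generator: RPN anchors for every position of a feature map.
    feat = paddle.to_tensor(np.random.rand(2, 48, 16, 16).astype('float32'))
    anchors, anchor_var = ops.anchor_generator(
        input=feat,
        anchor_sizes=[64, 128, 256, 512],
        aspect_ratios=[0.5, 1.0, 2.0],
        variance=[0.1, 0.1, 0.2, 0.2],
        stride=[16.0, 16.0],
        offset=0.5)

    # multiclass_nms: rois_num must be provided in dygraph mode.
    bboxes = paddle.to_tensor(np.random.rand(81, 4).astype('float32'))
    scores = paddle.to_tensor(np.random.rand(81).astype('float32'))
    rois_num = paddle.to_tensor(np.array([40, 41], dtype='int32'))
    out, index, nms_rois_num = ops.multiclass_nms(
        bboxes=bboxes,
        scores=scores,
        background_label=0,
        score_threshold=0.5,
        nms_top_k=400,
        nms_threshold=0.3,
        keep_top_k=200,
        normalized=False,
        return_index=True,
        rois_num=rois_num)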