未验证 提交 3fc2622d 编写于 作者: G Guanghua Yu 提交者: GitHub

Add multiclass_nms,anchor_generator,prior_box in dygraph (#1627)

上级 eed296b0
......@@ -45,7 +45,7 @@ class AnchorGeneratorRPN(object):
stride = self.stride if (
level is None or self.anchor_start_size is None) else (
self.stride[0] * (2.**level), self.stride[1] * (2.**level))
anchor, var = fluid.layers.anchor_generator(
anchor, var = ops.anchor_generator(
input=input,
anchor_sizes=anchor_sizes,
aspect_ratios=self.aspect_ratios,
......@@ -367,7 +367,7 @@ class DecodeClipNms(object):
@register
@serializable
class MultiClassNMS(object):
__op__ = fluid.layers.multiclass_nms
__op__ = ops.multiclass_nms
__append_doc__ = True
def __init__(self,
......
......@@ -27,13 +27,13 @@ from functools import reduce
__all__ = [
'roi_pool',
'roi_align',
#'prior_box',
#'anchor_generator',
'prior_box',
'anchor_generator',
#'generate_proposals',
'iou_similarity',
#'box_coder',
'yolo_box',
#'multiclass_nms',
'multiclass_nms',
'distribute_fpn_proposals',
'collect_fpn_proposals',
'matrix_nms',
......@@ -84,6 +84,7 @@ def roi_pool(input,
.. code-block:: python
import paddle
from ppdet.modeling import ops
paddle.enable_static()
x = paddle.static.data(
......@@ -187,6 +188,7 @@ def roi_align(input,
.. code-block:: python
import paddle
from ppdet.modeling import ops
paddle.enable_static()
x = paddle.static.data(
......@@ -278,12 +280,12 @@ def iou_similarity(x, y, box_normalized=True, name=None):
Examples:
.. code-block:: python
import numpy as np
import paddle
from ppdet.modeling import ops
paddle.enable_static()
x = paddle.data(name='x', shape=[None, 4], dtype='float32')
y = paddle.data(name='y', shape=[None, 4], dtype='float32')
x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
y = paddle.static.data(name='y', shape=[None, 4], dtype='float32')
iou = ops.iou_similarity(x=x, y=y)
"""
......@@ -355,8 +357,8 @@ def collect_fpn_proposals(multi_rois,
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
from ppdet.modeling import ops
paddle.enable_static()
multi_rois = []
multi_scores = []
......@@ -367,7 +369,7 @@ def collect_fpn_proposals(multi_rois,
multi_scores.append(paddle.static.data(
name='score_'+str(i), shape=[None, 1], dtype='float32', lod_level=1))
fpn_rois = fluid.layers.collect_fpn_proposals(
fpn_rois = ops.collect_fpn_proposals(
multi_rois=multi_rois,
multi_scores=multi_scores,
min_level=2,
......@@ -475,12 +477,12 @@ def distribute_fpn_proposals(fpn_rois,
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
from ppdet.modeling import ops
paddle.enable_static()
fpn_rois = paddle.static.data(
name='data', shape=[None, 4], dtype='float32', lod_level=1)
multi_rois, restore_ind = fluid.layers.distribute_fpn_proposals(
multi_rois, restore_ind = ops.distribute_fpn_proposals(
fpn_rois=fpn_rois,
min_level=2,
max_level=5,
......@@ -621,6 +623,7 @@ def yolo_box(
.. code-block:: python
import paddle
from ppdet.modeling import ops
paddle.enable_static()
x = paddle.static.data(name='x', shape=[None, 255, 13, 13], dtype='float32')
......@@ -671,6 +674,411 @@ def yolo_box(
return boxes, scores
def prior_box(input,
image,
min_sizes,
max_sizes=None,
aspect_ratios=[1.],
variance=[0.1, 0.1, 0.2, 0.2],
flip=False,
clip=False,
steps=[0.0, 0.0],
offset=0.5,
min_max_aspect_ratios_order=False,
name=None):
"""
This op generates prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
Each position of the input produce N prior boxes, N is determined by
the count of min_sizes, max_sizes and aspect_ratios, The size of the
box is in range(min_size, max_size) interval, which is generated in
sequence according to the aspect_ratios.
Parameters:
input(Tensor): 4-D tensor(NCHW), the data type should be float32 or float64.
image(Tensor): 4-D tensor(NCHW), the input image data of PriorBoxOp,
the data type should be float32 or float64.
min_sizes(list|tuple|float): the min sizes of generated prior boxes.
max_sizes(list|tuple|None): the max sizes of generated prior boxes.
Default: None.
aspect_ratios(list|tuple|float): the aspect ratios of generated
prior boxes. Default: [1.].
variance(list|tuple): the variances to be encoded in prior boxes.
Default:[0.1, 0.1, 0.2, 0.2].
flip(bool): Whether to flip aspect ratios. Default:False.
clip(bool): Whether to clip out-of-boundary boxes. Default: False.
step(list|tuple): Prior boxes step across width and height, If
step[0] equals to 0.0 or step[1] equals to 0.0, the prior boxes step across
height or weight of the input will be automatically calculated.
Default: [0., 0.]
offset(float): Prior boxes center offset. Default: 0.5
min_max_aspect_ratios_order(bool): If set True, the output prior box is
in order of [min, max, aspect_ratios], which is consistent with
Caffe. Please note, this order affects the weights order of
convolution layer followed by and does not affect the final
detection results. Default: False.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Tuple: A tuple with two Variable (boxes, variances)
boxes(Tensor): the output prior boxes of PriorBox.
4-D tensor, the layout is [H, W, num_priors, 4].
H is the height of input, W is the width of input,
num_priors is the total box count of each position of input.
variances(Tensor): the expanded variances of PriorBox.
4-D tensor, the layput is [H, W, num_priors, 4].
H is the height of input, W is the width of input
num_priors is the total box count of each position of input
Examples:
.. code-block:: python
import paddle
from ppdet.modeling import ops
paddle.enable_static()
input = paddle.static.data(name="input", shape=[None,3,6,9])
image = paddle.static.data(name="image", shape=[None,3,9,12])
box, var = ops.prior_box(
input=input,
image=image,
min_sizes=[100.],
clip=True,
flip=True)
"""
helper = LayerHelper("prior_box", **locals())
dtype = helper.input_dtype()
check_variable_and_dtype(
input, 'input', ['uint8', 'int8', 'float32', 'float64'], 'prior_box')
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
if not _is_list_or_tuple_(min_sizes):
min_sizes = [min_sizes]
if not _is_list_or_tuple_(aspect_ratios):
aspect_ratios = [aspect_ratios]
if not (_is_list_or_tuple_(steps) and len(steps) == 2):
raise ValueError('steps should be a list or tuple ',
'with length 2, (step_width, step_height).')
min_sizes = list(map(float, min_sizes))
aspect_ratios = list(map(float, aspect_ratios))
steps = list(map(float, steps))
cur_max_sizes = None
if max_sizes is not None and len(max_sizes) > 0 and max_sizes[0] > 0:
if not _is_list_or_tuple_(max_sizes):
max_sizes = [max_sizes]
cur_max_sizes = max_sizes
if in_dygraph_mode():
attrs = [
'min_sizes', min_sizes, 'aspect_ratios', aspect_ratios, 'variances',
variance, 'flip', flip, 'clip', clip, 'step_w', steps[0], 'step_h',
steps[1], 'offset', offset, 'min_max_aspect_ratios_order',
min_max_aspect_ratios_order
]
if cur_max_sizes is not None:
attrs.extend('max_sizes', max_sizes)
attrs = tuple(attrs)
box, var = core.ops.prior_box(input, image, *attrs)
return box, var
attrs = {
'min_sizes': min_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'flip': flip,
'clip': clip,
'step_w': steps[0],
'step_h': steps[1],
'offset': offset,
'min_max_aspect_ratios_order': min_max_aspect_ratios_order
}
if cur_max_sizes is not None:
attrs['max_sizes'] = cur_max_sizes
box = helper.create_variable_for_type_inference(dtype)
var = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="prior_box",
inputs={"Input": input,
"Image": image},
outputs={"Boxes": box,
"Variances": var},
attrs=attrs, )
box.stop_gradient = True
var.stop_gradient = True
return box, var
def anchor_generator(input,
anchor_sizes=None,
aspect_ratios=None,
variance=[0.1, 0.1, 0.2, 0.2],
stride=None,
offset=0.5,
name=None):
"""
This op generate anchors for Faster RCNN algorithm.
Each position of the input produce N anchors, N =
size(anchor_sizes) * size(aspect_ratios). The order of generated anchors
is firstly aspect_ratios loop then anchor_sizes loop.
Args:
input(Tensor): 4-D Tensor with shape [N,C,H,W]. The input feature map.
anchor_sizes(float32|list|tuple, optional): The anchor sizes of generated
anchors, given in absolute pixels e.g. [64., 128., 256., 512.].
For instance, the anchor size of 64 means the area of this anchor
equals to 64**2. None by default.
aspect_ratios(float32|list|tuple, optional): The height / width ratios
of generated anchors, e.g. [0.5, 1.0, 2.0]. None by default.
variance(list|tuple, optional): The variances to be used in box
regression deltas. The data type is float32, [0.1, 0.1, 0.2, 0.2] by
default.
stride(list|tuple, optional): The anchors stride across width and height.
The data type is float32. e.g. [16.0, 16.0]. None by default.
offset(float32, optional): Prior boxes center offset. 0.5 by default.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and None
by default.
Returns:
Tuple:
Anchors(Tensor): The output anchors with a layout of [H, W, num_anchors, 4].
H is the height of input, W is the width of input,
num_anchors is the box count of each position.
Each anchor is in (xmin, ymin, xmax, ymax) format an unnormalized.
Variances(Tensor): The expanded variances of anchors
with a layout of [H, W, num_priors, 4].
H is the height of input, W is the width of input
num_anchors is the box count of each position.
Each variance is in (xcenter, ycenter, w, h) format.
Examples:
.. code-block:: python
import paddle
from ppdet.modeling import ops
paddle.enable_static()
conv1 = paddle.static.data(name='input', shape=[None, 48, 16, 16], dtype='float32')
anchor, var = ops.anchor_generator(
input=conv1,
anchor_sizes=[64, 128, 256, 512],
aspect_ratios=[0.5, 1.0, 2.0],
variance=[0.1, 0.1, 0.2, 0.2],
stride=[16.0, 16.0],
offset=0.5)
"""
helper = LayerHelper("anchor_generator", **locals())
dtype = helper.input_dtype()
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
if not _is_list_or_tuple_(anchor_sizes):
anchor_sizes = [anchor_sizes]
if not _is_list_or_tuple_(aspect_ratios):
aspect_ratios = [aspect_ratios]
if not (_is_list_or_tuple_(stride) and len(stride) == 2):
raise ValueError('stride should be a list or tuple ',
'with length 2, (stride_width, stride_height).')
anchor_sizes = list(map(float, anchor_sizes))
aspect_ratios = list(map(float, aspect_ratios))
stride = list(map(float, stride))
if in_dygraph_mode():
attrs = ('anchor_sizes', anchor_sizes, 'aspect_ratios', aspect_ratios,
'variances', variance, 'stride', stride, 'offset', offset)
anchor, var = core.ops.anchor_generator(input, *attrs)
return anchor, var
attrs = {
'anchor_sizes': anchor_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'stride': stride,
'offset': offset
}
anchor = helper.create_variable_for_type_inference(dtype)
var = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="anchor_generator",
inputs={"Input": input},
outputs={"Anchors": anchor,
"Variances": var},
attrs=attrs, )
anchor.stop_gradient = True
var.stop_gradient = True
return anchor, var
def multiclass_nms(bboxes,
scores,
score_threshold,
nms_top_k,
keep_top_k,
nms_threshold=0.3,
normalized=True,
nms_eta=1.,
background_label=0,
return_index=False,
rois_num=None,
name=None):
"""
This operator is to do multi-class non maximum suppression (NMS) on
boxes and scores.
In the NMS step, this operator greedily selects a subset of detection bounding
boxes that have high scores larger than score_threshold, if providing this
threshold, then selects the largest nms_top_k confidences scores if nms_top_k
is larger than -1. Then this operator pruns away boxes that have high IOU
(intersection over union) overlap with already selected boxes by adaptive
threshold NMS based on parameters of nms_threshold and nms_eta.
Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
per image if keep_top_k is larger than -1.
Args:
bboxes (Tensor): Two types of bboxes are supported:
1. (Tensor) A 3-D Tensor with shape
[N, M, 4 or 8 16 24 32] represents the
predicted locations of M bounding bboxes,
N is the batch size. Each bounding box has four
coordinate values and the layout is
[xmin, ymin, xmax, ymax], when box size equals to 4.
2. (LoDTensor) A 3-D Tensor with shape [M, C, 4]
M is the number of bounding boxes, C is the
class number
scores (Tensor): Two types of scores are supported:
1. (Tensor) A 3-D Tensor with shape [N, C, M]
represents the predicted confidence predictions.
N is the batch size, C is the class number, M is
number of bounding boxes. For each category there
are total M scores which corresponding M bounding
boxes. Please note, M is equal to the 2nd dimension
of BBoxes.
2. (LoDTensor) A 2-D LoDTensor with shape [M, C].
M is the number of bbox, C is the class number.
In this case, input BBoxes should be the second
case with shape [M, C, 4].
background_label (int): The index of background label, the background
label will be ignored. If set to -1, then all
categories will be considered. Default: 0
score_threshold (float): Threshold to filter out bounding boxes with
low confidence score. If not provided,
consider all boxes.
nms_top_k (int): Maximum number of detections to be kept according to
the confidences after the filtering detections based
on score_threshold.
nms_threshold (float): The threshold to be used in NMS. Default: 0.3
nms_eta (float): The threshold to be used in NMS. Default: 1.0
keep_top_k (int): Number of total bboxes to be kept per image after NMS
step. -1 means keeping all bboxes after NMS step.
normalized (bool): Whether detections are normalized. Default: True
return_index(bool): Whether return selected index. Default: False
rois_num(Tensor): 1-D Tensor contains the number of RoIs in each image.
The shape is [B] and data type is int32. B is the number of images.
If it is not None then return a list of 1-D Tensor. Each element
is the output RoIs' number of each image on the corresponding level
and the shape is [B]. None by default.
name(str): Name of the multiclass nms op. Default: None.
Returns:
A tuple with two Variables: (Out, Index) if return_index is True,
otherwise, a tuple with one Variable(Out) is returned.
Out: A 2-D LoDTensor with shape [No, 6] represents the detections.
Each row has 6 values: [label, confidence, xmin, ymin, xmax, ymax]
or A 2-D LoDTensor with shape [No, 10] represents the detections.
Each row has 10 values: [label, confidence, x1, y1, x2, y2, x3, y3,
x4, y4]. No is the total number of detections.
If all images have not detected results, all elements in LoD will be
0, and output tensor is empty (None).
Index: Only return when return_index is True. A 2-D LoDTensor with
shape [No, 1] represents the selected index which type is Integer.
The index is the absolute value cross batches. No is the same number
as Out. If the index is used to gather other attribute such as age,
one needs to reshape the input(N, M, 1) to (N * M, 1) as first, where
N is the batch size and M is the number of boxes.
Examples:
.. code-block:: python
import paddle
from ppdet.modeling import ops
boxes = paddle.static.data(name='bboxes', shape=[81, 4],
dtype='float32', lod_level=1)
scores = paddle.static.data(name='scores', shape=[81],
dtype='float32', lod_level=1)
out, index = ops.multiclass_nms(bboxes=boxes,
scores=scores,
background_label=0,
score_threshold=0.5,
nms_top_k=400,
nms_threshold=0.3,
keep_top_k=200,
normalized=False,
return_index=True)
"""
helper = LayerHelper('multiclass_nms3', **locals())
if in_dygraph_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
attrs = ('background_label', background_label, 'score_threshold',
score_threshold, 'nms_top_k', nms_top_k, 'nms_threshold',
nms_threshold, 'keep_top_k', keep_top_k, 'nms_eta', nms_eta,
'normalized', normalized)
output, index, nms_rois_num = core.ops.multiclass_nms3(bboxes, scores,
rois_num, *attrs)
if return_index:
return output, index, nms_rois_num
else:
return output, nms_rois_num
output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
index = helper.create_variable_for_type_inference(dtype='int')
inputs = {'BBoxes': bboxes, 'Scores': scores}
outputs = {'Out': output, 'Index': index}
if rois_num is not None:
inputs['RoisNum'] = rois_num
nms_rois_num = helper.create_variable_for_type_inference(dtype='int32')
outputs['NmsRoisNum'] = nms_rois_num
helper.append_op(
type="multiclass_nms3",
inputs=inputs,
attrs={
'background_label': background_label,
'score_threshold': score_threshold,
'nms_top_k': nms_top_k,
'nms_threshold': nms_threshold,
'keep_top_k': keep_top_k,
'nms_eta': nms_eta,
'normalized': normalized
},
outputs=outputs)
output.stop_gradient = True
index.stop_gradient = True
if return_index and rois_num is not None:
return output, index, nms_rois_num
elif return_index and rois_num is None:
return output, index
elif not return_index and rois_num is not None:
return output, nms_rois_num
return output
def matrix_nms(bboxes,
scores,
score_threshold,
......@@ -686,16 +1094,13 @@ def matrix_nms(bboxes,
name=None):
"""
**Matrix NMS**
This operator does matrix non maximum suppression (NMS).
First selects a subset of candidate bounding boxes that have higher scores
than score_threshold (if provided), then the top k candidate is selected if
nms_top_k is larger than -1. Score of the remaining candidate are then
decayed according to the Matrix NMS scheme.
Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
per image if keep_top_k is larger than -1.
Args:
bboxes (Tensor): A 3-D Tensor with shape [N, M, 4] represents the
predicted locations of M bounding bboxes,
......@@ -728,30 +1133,22 @@ def matrix_nms(bboxes,
return_index(bool): Whether return selected index. Default: False
return_rois_num(bool): whether return rois_num. Default: True
name(str): Name of the matrix nms op. Default: None.
Returns:
A tuple with three Tensor: (Out, Index, RoisNum) if return_index is True,
otherwise, a tuple with two Tensor (Out, RoisNum) is returned.
Out (Tensor): A 2-D Tensor with shape [No, 6] containing the
detection results.
Each row has 6 values: [label, confidence, xmin, ymin, xmax, ymax]
(After version 1.3, when no boxes detected, the lod is changed
from {0} to {1})
Index (Tensor): A 2-D Tensor with shape [No, 1] containing the
selected indices, which are absolute values cross batches.
rois_num (Tensor): A 1-D Tensor with shape [N] containing
the number of detected boxes in each image.
Examples:
.. code-block:: python
import paddle
from ppdet.modeling import ops
boxes = paddle.static.data(name='bboxes', shape=[None,81, 4],
dtype='float32', lod_level=1)
scores = paddle.static.data(name='scores', shape=[None,81],
......@@ -759,7 +1156,6 @@ def matrix_nms(bboxes,
out = ops.matrix_nms(bboxes=boxes, scores=scores, background_label=0,
score_threshold=0.5, post_threshold=0.1,
nms_top_k=400, keep_top_k=200, normalized=False)
"""
check_variable_and_dtype(bboxes, 'BBoxes', ['float32', 'float64'],
'matrix_nms')
......
......@@ -383,7 +383,7 @@ class TestIoUSimilarity(LayerTest):
self.assertTrue(np.array_equal(iou_np, iou_dy_np))
class TestYOLO_Box(LayerTest):
class TestYOLOBox(LayerTest):
def test_yolo_box(self):
# x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2
......@@ -414,11 +414,6 @@ class TestYOLO_Box(LayerTest):
feed={
'x': np_x,
'origin_shape': np_origin_shape,
'anchors': [10, 13, 30, 13],
'class_num': 10,
'conf_thresh': 0.01,
'downsample_ratio': 32,
'scale_x_y': 1.0,
},
fetch_list=[boxes, scores],
with_lod=False)
......@@ -439,8 +434,8 @@ class TestYOLO_Box(LayerTest):
boxes_dy_np = boxes_dy.numpy()
scores_dy_np = scores_dy.numpy()
self.assertTrue(np.array_equal(boxes_np, boxes_dy_np))
self.assertTrue(np.array_equal(scores_np, scores_dy_np))
self.assertTrue(np.array_equal(boxes_np, boxes_dy_np))
self.assertTrue(np.array_equal(scores_np, scores_dy_np))
def test_yolo_box_error(self):
paddle.enable_static()
......@@ -463,6 +458,185 @@ class TestYOLO_Box(LayerTest):
scale_x_y=1.2)
class TestPriorBox(LayerTest):
def test_prior_box(self):
input_np = np.random.rand(2, 10, 32, 32).astype('float32')
image_np = np.random.rand(2, 10, 40, 40).astype('float32')
min_sizes = [2, 4]
with self.static_graph():
input = paddle.static.data(
name='input', shape=[2, 10, 32, 32], dtype='float32')
image = paddle.static.data(
name='image', shape=[2, 10, 40, 40], dtype='float32')
box, var = ops.prior_box(
input=input,
image=image,
min_sizes=min_sizes,
clip=True,
flip=True)
box_np, var_np = self.get_static_graph_result(
feed={
'input': input_np,
'image': image_np,
},
fetch_list=[box, var],
with_lod=False)
with self.dynamic_graph():
inputs_dy = base.to_variable(input_np)
image_dy = base.to_variable(image_np)
box_dy, var_dy = ops.prior_box(
input=inputs_dy,
image=image_dy,
min_sizes=min_sizes,
clip=True,
flip=True)
box_dy_np = box_dy.numpy()
var_dy_np = var_dy.numpy()
self.assertTrue(np.array_equal(box_np, box_dy_np))
self.assertTrue(np.array_equal(var_np, var_dy_np))
def test_prior_box_error(self):
program = Program()
paddle.enable_static()
with program_guard(program):
input = paddle.static.data(
name='input', shape=[2, 10, 32, 32], dtype='int32')
image = paddle.static.data(
name='image', shape=[2, 10, 40, 40], dtype='int32')
self.assertRaises(
TypeError,
ops.prior_box,
input=input,
image=image,
min_sizes=[2, 4],
clip=True,
flip=True)
class TestAnchorGenerator(LayerTest):
def test_anchor_generator(self):
b, c, h, w = 2, 48, 16, 16
input_np = np.random.rand(2, 48, 16, 16).astype('float32')
paddle.enable_static()
with self.static_graph():
input = paddle.static.data(
name='input', shape=[b, c, h, w], dtype='float32')
anchor, var = ops.anchor_generator(
input=input,
anchor_sizes=[64, 128, 256, 512],
aspect_ratios=[0.5, 1.0, 2.0],
variance=[0.1, 0.1, 0.2, 0.2],
stride=[16.0, 16.0],
offset=0.5)
anchor_np, var_np = self.get_static_graph_result(
feed={'input': input_np, },
fetch_list=[anchor, var],
with_lod=False)
with self.dynamic_graph():
inputs_dy = base.to_variable(input_np)
anchor_dy, var_dy = ops.anchor_generator(
input=inputs_dy,
anchor_sizes=[64, 128, 256, 512],
aspect_ratios=[0.5, 1.0, 2.0],
variance=[0.1, 0.1, 0.2, 0.2],
stride=[16.0, 16.0],
offset=0.5)
anchor_dy_np = anchor_dy.numpy()
var_dy_np = var_dy.numpy()
self.assertTrue(np.array_equal(anchor_np, anchor_dy_np))
self.assertTrue(np.array_equal(var_np, var_dy_np))
class TestMulticlassNms(LayerTest):
def test_multiclass_nms(self):
boxes_np = np.random.rand(81, 4).astype('float32')
scores_np = np.random.rand(81).astype('float32')
rois_num_np = np.array([40, 41]).astype('int32')
with self.static_graph():
boxes = paddle.static.data(
name='bboxes', shape=[81, 4], dtype='float32', lod_level=1)
scores = paddle.static.data(
name='scores', shape=[81], dtype='float32', lod_level=1)
rois_num = paddle.static.data(
name='rois_num', shape=[40, 41], dtype='int32')
output = ops.multiclass_nms(
bboxes=boxes,
scores=scores,
background_label=0,
score_threshold=0.5,
nms_top_k=400,
nms_threshold=0.3,
keep_top_k=200,
normalized=False,
return_index=True,
rois_num=rois_num)
out_np, index_np, nms_rois_num_np = self.get_static_graph_result(
feed={
'bboxes': boxes_np,
'scores': scores_np,
'rois_num': rois_num_np
},
fetch_list=output,
with_lod=False)
with self.dynamic_graph():
boxes_dy = base.to_variable(boxes_np)
scores_dy = base.to_variable(scores_np)
rois_num_dy = base.to_variable(rois_num_np)
out_dy, index_dy, nms_rois_num_dy = ops.multiclass_nms(
bboxes=boxes_dy,
scores=scores_dy,
background_label=0,
score_threshold=0.5,
nms_top_k=400,
nms_threshold=0.3,
keep_top_k=200,
normalized=False,
return_index=True,
rois_num=rois_num_dy)
out_dy_np = out_dy.numpy()
index_dy_np = index_dy.numpy()
nms_rois_num_dy_np = nms_rois_num_dy.numpy()
self.assertTrue(np.array_equal(out_np, out_dy_np))
self.assertTrue(np.array_equal(index_np, index_dy_np))
self.assertTrue(np.array_equal(nms_rois_num_np, nms_rois_num_dy_np))
def test_multiclass_nms_error(self):
program = Program()
paddle.enable_static()
with program_guard(program):
boxes = paddle.static.data(
name='bboxes', shape=[81, 4], dtype='float32', lod_level=1)
scores = paddle.static.data(
name='scores', shape=[81], dtype='float32', lod_level=1)
rois_num = paddle.static.data(
name='rois_num', shape=[40, 41], dtype='int32')
self.assertRaises(
TypeError,
ops.multiclass_nms,
boxes=boxes,
scores=scores,
background_label=0,
score_threshold=0.5,
nms_top_k=400,
nms_threshold=0.3,
keep_top_k=200,
normalized=False,
return_index=True,
rois_num=rois_num)
class TestMatrixNMS(LayerTest):
def test_matrix_nms(self):
N, M, C = 7, 1200, 21
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册