未验证 提交 2e299ad0 编写于 作者: F Feng Ni 提交者: GitHub

add prior_box and box_coder for paddle.vision.ops (#47282)

上级 3a0690e6
......@@ -996,63 +996,15 @@ def box_coder(
box_normalized=False,
axis=1)
"""
check_variable_and_dtype(
prior_box, 'prior_box', ['float32', 'float64'], 'box_coder'
)
check_variable_and_dtype(
target_box, 'target_box', ['float32', 'float64'], 'box_coder'
)
if in_dygraph_mode():
if isinstance(prior_box_var, Variable):
box_coder_op = _C_ops.box_coder(
prior_box,
prior_box_var,
target_box,
code_type,
box_normalized,
axis,
[],
)
elif isinstance(prior_box_var, list):
box_coder_op = _C_ops.box_coder(
prior_box,
None,
target_box,
code_type,
box_normalized,
axis,
prior_box_var,
)
else:
raise TypeError(
"Input variance of box_coder must be Variable or lisz"
)
return box_coder_op
helper = LayerHelper("box_coder", **locals())
output_box = helper.create_variable_for_type_inference(
dtype=prior_box.dtype
)
inputs = {"PriorBox": prior_box, "TargetBox": target_box}
attrs = {
"code_type": code_type,
"box_normalized": box_normalized,
"axis": axis,
}
if isinstance(prior_box_var, Variable):
inputs['PriorBoxVar'] = prior_box_var
elif isinstance(prior_box_var, list):
attrs['variance'] = prior_box_var
else:
raise TypeError("Input variance of box_coder must be Variable or lisz")
helper.append_op(
type="box_coder",
inputs=inputs,
attrs=attrs,
outputs={"OutputBox": output_box},
return paddle.vision.ops.box_coder(
prior_box=prior_box,
prior_box_var=prior_box_var,
target_box=target_box,
code_type=code_type,
box_normalized=box_normalized,
axis=axis,
name=name,
)
return output_box
@templatedoc()
......@@ -1974,8 +1926,8 @@ def prior_box(
#declarative mode
import paddle.fluid as fluid
import numpy as np
import paddle
paddle.enable_static()
import paddle
paddle.enable_static()
input = fluid.data(name="input", shape=[None,3,6,9])
image = fluid.data(name="image", shape=[None,3,9,12])
box, var = fluid.layers.prior_box(
......@@ -2021,75 +1973,20 @@ def prior_box(
# [6L, 9L, 1L, 4L]
"""
if in_dygraph_mode():
step_w, step_h = steps
if max_sizes == None:
max_sizes = []
return _C_ops.prior_box(
input,
image,
min_sizes,
aspect_ratios,
variance,
max_sizes,
flip,
clip,
step_w,
step_h,
offset,
min_max_aspect_ratios_order,
)
helper = LayerHelper("prior_box", **locals())
dtype = helper.input_dtype()
check_variable_and_dtype(
input, 'input', ['uint8', 'int8', 'float32', 'float64'], 'prior_box'
)
def _is_list_or_tuple_(data):
return isinstance(data, list) or isinstance(data, tuple)
if not _is_list_or_tuple_(min_sizes):
min_sizes = [min_sizes]
if not _is_list_or_tuple_(aspect_ratios):
aspect_ratios = [aspect_ratios]
if not (_is_list_or_tuple_(steps) and len(steps) == 2):
raise ValueError(
'steps should be a list or tuple ',
'with length 2, (step_width, step_height).',
)
min_sizes = list(map(float, min_sizes))
aspect_ratios = list(map(float, aspect_ratios))
steps = list(map(float, steps))
attrs = {
'min_sizes': min_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'flip': flip,
'clip': clip,
'step_w': steps[0],
'step_h': steps[1],
'offset': offset,
'min_max_aspect_ratios_order': min_max_aspect_ratios_order,
}
if max_sizes is not None and len(max_sizes) > 0 and max_sizes[0] > 0:
if not _is_list_or_tuple_(max_sizes):
max_sizes = [max_sizes]
attrs['max_sizes'] = max_sizes
box = helper.create_variable_for_type_inference(dtype)
var = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="prior_box",
inputs={"Input": input, "Image": image},
outputs={"Boxes": box, "Variances": var},
attrs=attrs,
return paddle.vision.ops.prior_box(
input=input,
image=image,
min_sizes=min_sizes,
max_sizes=max_sizes,
aspect_ratios=aspect_ratios,
variance=variance,
flip=flip,
clip=clip,
steps=steps,
offset=offset,
min_max_aspect_ratios_order=min_max_aspect_ratios_order,
name=name,
)
box.stop_gradient = True
var.stop_gradient = True
return box, var
def density_prior_box(
......
......@@ -319,5 +319,62 @@ class TestBoxCoderOpWithVarianceDygraphAPI(unittest.TestCase):
run(place)
class TestBoxCoderAPI(unittest.TestCase):
    """Verify paddle.vision.ops.box_coder returns identical results under
    static-graph and dygraph execution."""

    def setUp(self):
        # Seed once so both execution modes receive the same random inputs.
        np.random.seed(678)
        self.prior_box_np = np.random.random((80, 4)).astype('float32')
        self.prior_box_var_np = np.random.random((80, 4)).astype('float32')
        self.target_box_np = np.random.random((20, 80, 4)).astype('float32')

    def test_dygraph_with_static(self):
        # ---- static-graph execution ----
        paddle.enable_static()
        pb = paddle.static.data(
            name='prior_box', shape=[80, 4], dtype='float32'
        )
        pb_var = paddle.static.data(
            name='prior_box_var', shape=[80, 4], dtype='float32'
        )
        tb = paddle.static.data(
            name='target_box', shape=[20, 80, 4], dtype='float32'
        )
        decoded = paddle.vision.ops.box_coder(
            prior_box=pb,
            prior_box_var=pb_var,
            target_box=tb,
            code_type="decode_center_size",
            box_normalized=False,
        )
        executor = paddle.static.Executor()
        static_out = executor.run(
            paddle.static.default_main_program(),
            feed={
                'prior_box': self.prior_box_np,
                'prior_box_var': self.prior_box_var_np,
                'target_box': self.target_box_np,
            },
            fetch_list=[decoded],
        )

        # ---- dygraph execution ----
        paddle.disable_static()
        decoded_dy = paddle.vision.ops.box_coder(
            prior_box=paddle.to_tensor(self.prior_box_np),
            prior_box_var=paddle.to_tensor(self.prior_box_var_np),
            target_box=paddle.to_tensor(self.target_box_np),
            code_type="decode_center_size",
            box_normalized=False,
        )

        # Both modes must agree element-wise.
        np.testing.assert_allclose(static_out[0], decoded_dy.numpy())
        # Restore static mode for any test that runs after this one.
        paddle.enable_static()
if __name__ == '__main__':
    # Tests default to the static-graph path, so start in static mode.
    paddle.enable_static()
    unittest.main()
......@@ -109,9 +109,6 @@ class TestPriorBoxOp(OpTest):
self.flip = True
self.set_min_max_aspect_ratios_order()
self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
self.aspect_ratios = np.array(
self.aspect_ratios, dtype=np.float64
).flatten()
self.variances = [0.1, 0.1, 0.2, 0.2]
self.variances = np.array(self.variances, dtype=np.float64).flatten()
......@@ -225,6 +222,59 @@ class TestPriorBoxOpWithSpecifiedOutOrder(TestPriorBoxOp):
self.min_max_aspect_ratios_order = True
class TestPriorBoxAPI(unittest.TestCase):
    """Verify paddle.vision.ops.prior_box returns identical results under
    static-graph and dygraph execution."""

    def setUp(self):
        # Seed once so both execution modes receive the same random inputs.
        np.random.seed(678)
        self.input_np = np.random.rand(2, 10, 32, 32).astype('float32')
        self.image_np = np.random.rand(2, 10, 40, 40).astype('float32')
        self.min_sizes = [2.0, 4.0]

    def test_dygraph_with_static(self):
        # ---- static-graph execution ----
        paddle.enable_static()
        feat = paddle.static.data(
            name='input', shape=[2, 10, 32, 32], dtype='float32'
        )
        img = paddle.static.data(
            name='image', shape=[2, 10, 40, 40], dtype='float32'
        )
        boxes, variances = paddle.vision.ops.prior_box(
            input=feat,
            image=img,
            min_sizes=self.min_sizes,
            clip=True,
            flip=True,
        )
        executor = paddle.static.Executor()
        box_np, var_np = executor.run(
            paddle.static.default_main_program(),
            feed={
                'input': self.input_np,
                'image': self.image_np,
            },
            fetch_list=[boxes, variances],
        )

        # ---- dygraph execution ----
        paddle.disable_static()
        box_dy, var_dy = paddle.vision.ops.prior_box(
            input=paddle.to_tensor(self.input_np),
            image=paddle.to_tensor(self.image_np),
            min_sizes=self.min_sizes,
            clip=True,
            flip=True,
        )

        # Both modes must agree element-wise.
        np.testing.assert_allclose(box_np, box_dy.numpy())
        np.testing.assert_allclose(var_np, var_dy.numpy())
        # Restore static mode for any test that runs after this one.
        paddle.enable_static()
if __name__ == '__main__':
    # Tests default to the static-graph path, so start in static mode.
    paddle.enable_static()
    unittest.main()
......@@ -19,6 +19,7 @@ from ..fluid.layers import nn, utils
from ..nn import Layer, Conv2D, Sequential, ReLU, BatchNorm2D
from ..fluid.initializer import Normal
from ..fluid.framework import (
Variable,
_non_static_mode,
in_dygraph_mode,
_in_legacy_dygraph,
......@@ -29,6 +30,8 @@ from ..framework import _current_expected_place
__all__ = [ # noqa
'yolo_loss',
'yolo_box',
'prior_box',
'box_coder',
'deform_conv2d',
'DeformConv2D',
'distribute_fpn_proposals',
......@@ -479,6 +482,379 @@ def yolo_box(
return boxes, scores
def prior_box(
    input,
    image,
    min_sizes,
    max_sizes=None,
    aspect_ratios=[1.0],
    variance=[0.1, 0.1, 0.2, 0.2],
    flip=False,
    clip=False,
    steps=[0.0, 0.0],
    offset=0.5,
    min_max_aspect_ratios_order=False,
    name=None,
):
    r"""
    This op generates prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
    Each position of the input produce N prior boxes, N is determined by
    the count of min_sizes, max_sizes and aspect_ratios, The size of the
    box is in range(min_size, max_size) interval, which is generated in
    sequence according to the aspect_ratios.

    Args:
        input (Tensor): 4-D tensor(NCHW), the data type should be float32 or float64.
        image (Tensor): 4-D tensor(NCHW), the input image data of PriorBoxOp,
            the data type should be float32 or float64.
        min_sizes (list|tuple|float): the min sizes of generated prior boxes.
        max_sizes (list|tuple|None, optional): the max sizes of generated prior boxes.
            Default: None, means [] and will not be used.
        aspect_ratios (list|tuple|float, optional): the aspect ratios of generated
            prior boxes. Default: [1.0].
        variance (list|tuple, optional): the variances to be encoded in prior boxes.
            Default:[0.1, 0.1, 0.2, 0.2].
        flip (bool): Whether to flip aspect ratios. Default:False.
        clip (bool): Whether to clip out-of-boundary boxes. Default: False.
        steps (list|tuple, optional): Prior boxes steps across width and height, If
            steps[0] equals to 0.0 or steps[1] equals to 0.0, the prior boxes steps across
            height or width of the input will be automatically calculated.
            Default: [0., 0.]
        offset (float, optional)): Prior boxes center offset. Default: 0.5
        min_max_aspect_ratios_order (bool, optional): If set True, the output prior box is
            in order of [min, max, aspect_ratios], which is consistent with
            Caffe. Please note, this order affects the weights order of
            convolution layer followed by and does not affect the final
            detection results. Default: False.
        name (str, optional): The default value is None. Normally there is no need for
            user to set this property. For more information, please refer to :ref:`api_guide_Name`

    Returns:
        Tensor: the output prior boxes and the expanded variances of PriorBox.
            The prior boxes is a 4-D tensor, the layout is [H, W, num_priors, 4],
            num_priors is the total box count of each position of input.
            The expanded variances is a 4-D tensor, same shape as the prior boxes.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.rand((1, 3, 6, 9), dtype=paddle.float32)
            image = paddle.rand((1, 3, 9, 12), dtype=paddle.float32)

            box, var = paddle.vision.ops.prior_box(
                input=input,
                image=image,
                min_sizes=[2.0, 4.0],
                clip=True,
                flip=True)
    """
    # NOTE: the list defaults in the signature are kept for API compatibility;
    # they are safe because the function only rebinds them, never mutates them.
    helper = LayerHelper("prior_box", **locals())
    dtype = helper.input_dtype()
    check_variable_and_dtype(
        input, 'input', ['uint8', 'int8', 'float32', 'float64'], 'prior_box'
    )

    def _is_list_or_tuple_(data):
        return isinstance(data, (list, tuple))

    # Normalize scalar arguments to lists and validate `steps` before
    # dispatching, so all execution modes see the same canonical values.
    if not _is_list_or_tuple_(min_sizes):
        min_sizes = [min_sizes]
    if not _is_list_or_tuple_(aspect_ratios):
        aspect_ratios = [aspect_ratios]
    if not _is_list_or_tuple_(steps):
        steps = [steps]
    if not len(steps) == 2:
        raise ValueError('steps should be (step_w, step_h)')

    min_sizes = list(map(float, min_sizes))
    aspect_ratios = list(map(float, aspect_ratios))
    steps = list(map(float, steps))

    # `max_sizes` only participates when it holds at least one positive value.
    cur_max_sizes = None
    if max_sizes is not None and len(max_sizes) > 0 and max_sizes[0] > 0:
        if not _is_list_or_tuple_(max_sizes):
            max_sizes = [max_sizes]
        cur_max_sizes = max_sizes

    if in_dygraph_mode():
        step_w, step_h = steps
        # BUGFIX: compare against None with `is`, not `==` (PEP 8; `==` may
        # invoke arbitrary __eq__ overloads).
        if max_sizes is None:
            max_sizes = []
        box, var = _C_ops.prior_box(
            input,
            image,
            min_sizes,
            aspect_ratios,
            variance,
            max_sizes,
            flip,
            clip,
            step_w,
            step_h,
            offset,
            min_max_aspect_ratios_order,
        )
        return box, var

    if _in_legacy_dygraph():
        # Legacy dygraph passes attributes as a flat name/value tuple.
        attrs = (
            'min_sizes',
            min_sizes,
            'aspect_ratios',
            aspect_ratios,
            'variances',
            variance,
            'flip',
            flip,
            'clip',
            clip,
            'step_w',
            steps[0],
            'step_h',
            steps[1],
            'offset',
            offset,
            'min_max_aspect_ratios_order',
            min_max_aspect_ratios_order,
        )
        if cur_max_sizes is not None:
            attrs += ('max_sizes', cur_max_sizes)
        box, var = _legacy_C_ops.prior_box(input, image, *attrs)
        return box, var
    else:
        # Static-graph mode: append a `prior_box` op to the current program.
        attrs = {
            'min_sizes': min_sizes,
            'aspect_ratios': aspect_ratios,
            'variances': variance,
            'flip': flip,
            'clip': clip,
            'step_w': steps[0],
            'step_h': steps[1],
            'offset': offset,
            'min_max_aspect_ratios_order': min_max_aspect_ratios_order,
        }
        if cur_max_sizes is not None:
            attrs['max_sizes'] = cur_max_sizes

        box = helper.create_variable_for_type_inference(dtype)
        var = helper.create_variable_for_type_inference(dtype)
        helper.append_op(
            type="prior_box",
            inputs={"Input": input, "Image": image},
            outputs={"Boxes": box, "Variances": var},
            attrs=attrs,
        )
        # Prior boxes are fixed anchors; no gradients flow through them.
        box.stop_gradient = True
        var.stop_gradient = True
        return box, var
def box_coder(
    prior_box,
    prior_box_var,
    target_box,
    code_type="encode_center_size",
    box_normalized=True,
    axis=0,
    name=None,
):
    r"""
    Encode/Decode the target bounding box with the priorbox information.

    The Encoding schema described below:

    .. math::

        ox &= (tx - px) / pw / pxv

        oy &= (ty - py) / ph / pyv

        ow &= log(abs(tw / pw)) / pwv

        oh &= log(abs(th / ph)) / phv

    The Decoding schema described below:

    .. math::

        ox &= (pw * pxv * tx + px) - tw / 2

        oy &= (ph * pyv * ty + py) - th / 2

        ow &= exp(pwv * tw) * pw + tw / 2

        oh &= exp(phv * th) * ph + th / 2

    where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates,
    width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote
    the priorbox's (anchor) center coordinates, width and height. `pxv`,
    `pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`,
    `ow`, `oh` denote the encoded/decoded coordinates, width and height.

    During Box Decoding, two modes for broadcast are supported. Say target
    box has shape [N, M, 4], and the shape of prior box can be [N, 4] or
    [M, 4]. Then prior box will broadcast to target box along the
    assigned axis.

    Args:
        prior_box (Tensor): Box list prior_box is a 2-D Tensor with shape
            [M, 4] holds M boxes and data type is float32 or float64. Each box
            is represented as [xmin, ymin, xmax, ymax], [xmin, ymin] is the
            left top coordinate of the anchor box, if the input is image feature
            map, they are close to the origin of the coordinate system.
            [xmax, ymax] is the right bottom coordinate of the anchor box.
        prior_box_var (List|Tensor|None): prior_box_var supports three types
            of input. One is Tensor with shape [M, 4] which holds M group and
            data type is float32 or float64. The second is list consist of
            4 elements shared by all boxes and data type is float32 or float64.
            Other is None and not involved in calculation.
        target_box (Tensor): This input can be a 2-D LoDTensor with shape
            [N, 4] when code_type is 'encode_center_size'. This input also can
            be a 3-D Tensor with shape [N, M, 4] when code_type is
            'decode_center_size'. Each box is represented as
            [xmin, ymin, xmax, ymax]. The data type is float32 or float64.
        code_type (str, optional): The code type used with the target box. It can be
            `encode_center_size` or `decode_center_size`. `encode_center_size`
            by default.
        box_normalized (bool, optional): Whether treat the priorbox as a normalized box.
            Set true by default.
        axis (int, optional): Which axis in PriorBox to broadcast for box decode,
            for example, if axis is 0 and TargetBox has shape [N, M, 4] and
            PriorBox has shape [M, 4], then PriorBox will broadcast to [N, M, 4]
            for decoding. It is only valid when code type is
            `decode_center_size`. Set 0 by default.
        name (str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name is no need to set and
            None by default.

    Returns:
        Tensor: output boxes, when code_type is 'encode_center_size', the
            output tensor of box_coder_op with shape [N, M, 4] representing the
            result of N target boxes encoded with M Prior boxes and variances.
            When code_type is 'decode_center_size', N represents the batch size
            and M represents the number of decoded boxes.

    Examples:
        .. code-block:: python

            import paddle

            # For encode
            prior_box_encode = paddle.rand((80, 4), dtype=paddle.float32)
            prior_box_var_encode = paddle.rand((80, 4), dtype=paddle.float32)
            target_box_encode = paddle.rand((20, 4), dtype=paddle.float32)
            output_encode = paddle.vision.ops.box_coder(
                prior_box=prior_box_encode,
                prior_box_var=prior_box_var_encode,
                target_box=target_box_encode,
                code_type="encode_center_size")

            # For decode
            prior_box_decode = paddle.rand((80, 4), dtype=paddle.float32)
            prior_box_var_decode = paddle.rand((80, 4), dtype=paddle.float32)
            target_box_decode = paddle.rand((20, 80, 4), dtype=paddle.float32)
            output_decode = paddle.vision.ops.box_coder(
                prior_box=prior_box_decode,
                prior_box_var=prior_box_var_decode,
                target_box=target_box_decode,
                code_type="decode_center_size",
                box_normalized=False)
    """
    check_variable_and_dtype(
        prior_box, 'prior_box', ['float32', 'float64'], 'box_coder'
    )
    check_variable_and_dtype(
        target_box, 'target_box', ['float32', 'float64'], 'box_coder'
    )
    if in_dygraph_mode():
        # New dygraph mode: variance is passed either as a tensor input or as
        # the trailing list attribute, never both.
        if isinstance(prior_box_var, Variable):
            output_box = _C_ops.box_coder(
                prior_box,
                prior_box_var,
                target_box,
                code_type,
                box_normalized,
                axis,
                [],  # empty list attr: per-box variance comes from the tensor
            )
        elif isinstance(prior_box_var, list):
            output_box = _C_ops.box_coder(
                prior_box,
                None,
                target_box,
                code_type,
                box_normalized,
                axis,
                prior_box_var,
            )
        else:
            raise TypeError("Input prior_box_var must be Variable or list")
        return output_box

    if _in_legacy_dygraph():
        # Legacy dygraph: attributes are flat name/value pairs.
        if isinstance(prior_box_var, Variable):
            output_box = _legacy_C_ops.box_coder(
                prior_box,
                prior_box_var,
                target_box,
                "code_type",
                code_type,
                "box_normalized",
                box_normalized,
                "axis",
                axis,
            )
        elif isinstance(prior_box_var, list):
            output_box = _legacy_C_ops.box_coder(
                prior_box,
                None,
                target_box,
                "code_type",
                code_type,
                "box_normalized",
                box_normalized,
                "axis",
                axis,
                "variance",
                prior_box_var,
            )
        else:
            raise TypeError("Input prior_box_var must be Variable or list")
        return output_box
    else:
        # Static-graph mode: append a `box_coder` op to the current program.
        helper = LayerHelper("box_coder", **locals())
        output_box = helper.create_variable_for_type_inference(
            dtype=prior_box.dtype
        )
        inputs = {"PriorBox": prior_box, "TargetBox": target_box}
        attrs = {
            "code_type": code_type,
            "box_normalized": box_normalized,
            "axis": axis,
        }
        # Tensor variance becomes an op input; list variance becomes an attr.
        if isinstance(prior_box_var, Variable):
            inputs['PriorBoxVar'] = prior_box_var
        elif isinstance(prior_box_var, list):
            attrs['variance'] = prior_box_var
        else:
            raise TypeError("Input prior_box_var must be Variable or list")
        helper.append_op(
            type="box_coder",
            inputs=inputs,
            attrs=attrs,
            outputs={"OutputBox": output_box},
        )
        return output_box
def deform_conv2d(
x,
offset,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册