未验证 提交 5b6767ac 编写于 作者: Z zqw_1997 提交者: GitHub

[fluid remove]: remove paddle.fluid.layers.box_coder and...

[fluid remove]: remove paddle.fluid.layers.box_coder and paddle.fluid.layers.polygon_box_transform (#48896)

* remove fluid_box_coder and polygon_box_transform

* code check
上级 8bf7f638
......@@ -47,8 +47,6 @@ __all__ = [
'generate_proposal_labels',
'generate_proposals',
'generate_mask_labels',
'box_coder',
'polygon_box_transform',
'box_clip',
'multiclass_nms',
'locality_aware_nms',
......@@ -60,177 +58,6 @@ __all__ = [
]
@templatedoc()
def box_coder(
prior_box,
prior_box_var,
target_box,
code_type="encode_center_size",
box_normalized=True,
name=None,
axis=0,
):
r"""
**Box Coder Layer**
Encode/Decode the target bounding box with the priorbox information.
The Encoding schema described below:
.. math::
ox = (tx - px) / pw / pxv
oy = (ty - py) / ph / pyv
ow = \log(\abs(tw / pw)) / pwv
oh = \log(\abs(th / ph)) / phv
The Decoding schema described below:
.. math::
ox = (pw * pxv * tx * + px) - tw / 2
oy = (ph * pyv * ty * + py) - th / 2
ow = \exp(pwv * tw) * pw + tw / 2
oh = \exp(phv * th) * ph + th / 2
where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates,
width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote
the priorbox's (anchor) center coordinates, width and height. `pxv`,
`pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`,
`ow`, `oh` denote the encoded/decoded coordinates, width and height.
During Box Decoding, two modes for broadcast are supported. Say target
box has shape [N, M, 4], and the shape of prior box can be [N, 4] or
[M, 4]. Then prior box will broadcast to target box along the
assigned axis.
Args:
prior_box(Variable): Box list prior_box is a 2-D Tensor with shape
[M, 4] holds M boxes and data type is float32 or float64. Each box
is represented as [xmin, ymin, xmax, ymax], [xmin, ymin] is the
left top coordinate of the anchor box, if the input is image feature
map, they are close to the origin of the coordinate system.
[xmax, ymax] is the right bottom coordinate of the anchor box.
prior_box_var(List|Variable|None): prior_box_var supports three types
of input. One is variable with shape [M, 4] which holds M group and
data type is float32 or float64. The second is list consist of
4 elements shared by all boxes and data type is float32 or float64.
Other is None and not involved in calculation.
target_box(Variable): This input can be a 2-D LoDTensor with shape
[N, 4] when code_type is 'encode_center_size'. This input also can
be a 3-D Tensor with shape [N, M, 4] when code_type is
'decode_center_size'. Each box is represented as
[xmin, ymin, xmax, ymax]. The data type is float32 or float64.
This tensor can contain LoD information to represent a batch of inputs.
code_type(str): The code type used with the target box. It can be
`encode_center_size` or `decode_center_size`. `encode_center_size`
by default.
box_normalized(bool): Whether treat the priorbox as a normalized box.
Set true by default.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
axis(int): Which axis in PriorBox to broadcast for box decode,
for example, if axis is 0 and TargetBox has shape [N, M, 4] and
PriorBox has shape [M, 4], then PriorBox will broadcast to [N, M, 4]
for decoding. It is only valid when code type is
`decode_center_size`. Set 0 by default.
Returns:
Variable:
output_box(Variable): When code_type is 'encode_center_size', the
output tensor of box_coder_op with shape [N, M, 4] representing the
result of N target boxes encoded with M Prior boxes and variances.
When code_type is 'decode_center_size', N represents the batch size
and M represents the number of decoded boxes.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
paddle.enable_static()
# For encode
prior_box_encode = fluid.data(name='prior_box_encode',
shape=[512, 4],
dtype='float32')
target_box_encode = fluid.data(name='target_box_encode',
shape=[81, 4],
dtype='float32')
output_encode = fluid.layers.box_coder(prior_box=prior_box_encode,
prior_box_var=[0.1,0.1,0.2,0.2],
target_box=target_box_encode,
code_type="encode_center_size")
# For decode
prior_box_decode = fluid.data(name='prior_box_decode',
shape=[512, 4],
dtype='float32')
target_box_decode = fluid.data(name='target_box_decode',
shape=[512, 81, 4],
dtype='float32')
output_decode = fluid.layers.box_coder(prior_box=prior_box_decode,
prior_box_var=[0.1,0.1,0.2,0.2],
target_box=target_box_decode,
code_type="decode_center_size",
box_normalized=False,
axis=1)
"""
return paddle.vision.ops.box_coder(
prior_box=prior_box,
prior_box_var=prior_box_var,
target_box=target_box,
code_type=code_type,
box_normalized=box_normalized,
axis=axis,
name=name,
)
@templatedoc()
def polygon_box_transform(input, name=None):
"""
${comment}
Args:
input(Variable): The input with shape [batch_size, geometry_channels, height, width].
A Tensor with type float32, float64.
name(str, Optional): For details, please refer to :ref:`api_guide_Name`.
Generally, no setting is required. Default: None.
Returns:
Variable: The output with the same shape as input. A Tensor with type float32, float64.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.data(name='input', shape=[4, 10, 5, 5], dtype='float32')
out = fluid.layers.polygon_box_transform(input)
"""
check_variable_and_dtype(
input, "input", ['float32', 'float64'], 'polygon_box_transform'
)
helper = LayerHelper("polygon_box_transform", **locals())
output = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type="polygon_box_transform",
inputs={"Input": input},
attrs={},
outputs={"Output": output},
)
return output
def prior_box(
input,
image,
......
......@@ -75,51 +75,6 @@ class LayerTest(unittest.TestCase):
yield
class TestDetection(unittest.TestCase):
def test_box_coder_api(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[4], dtype='float32')
y = layers.data(name='z', shape=[4], dtype='float32', lod_level=1)
bcoder = layers.box_coder(
prior_box=x,
prior_box_var=[0.1, 0.2, 0.1, 0.2],
target_box=y,
code_type='encode_center_size',
)
self.assertIsNotNone(bcoder)
print(str(program))
def test_box_coder_error(self):
program = Program()
with program_guard(program):
x1 = fluid.data(name='x1', shape=[10, 4], dtype='int32')
y1 = fluid.data(
name='y1', shape=[10, 4], dtype='float32', lod_level=1
)
x2 = fluid.data(name='x2', shape=[10, 4], dtype='float32')
y2 = fluid.data(
name='y2', shape=[10, 4], dtype='int32', lod_level=1
)
self.assertRaises(
TypeError,
layers.box_coder,
prior_box=x1,
prior_box_var=[0.1, 0.2, 0.1, 0.2],
target_box=y1,
code_type='encode_center_size',
)
self.assertRaises(
TypeError,
layers.box_coder,
prior_box=x2,
prior_box_var=[0.1, 0.2, 0.1, 0.2],
target_box=y2,
code_type='encode_center_size',
)
class TestPriorBox(unittest.TestCase):
def test_prior_box(self):
program = Program()
......
......@@ -18,7 +18,6 @@ import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
def box_decoder(t_box, p_box, pb_v, output_box, norm, axis=0):
......@@ -114,7 +113,7 @@ class TestBoxCoderOp(OpTest):
def setUp(self):
self.op_type = "box_coder"
self.python_api = paddle.fluid.layers.box_coder
self.python_api = paddle.vision.ops.box_coder
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((81, 4)).astype('float32')
prior_box_var = np.random.random((81, 4)).astype('float32')
......@@ -146,7 +145,7 @@ class TestBoxCoderOpWithoutBoxVar(OpTest):
self.check_output(check_eager=True)
def setUp(self):
self.python_api = paddle.fluid.layers.box_coder
self.python_api = paddle.vision.ops.box_coder
self.op_type = "box_coder"
lod = [[0, 1, 2, 3, 4, 5]]
prior_box = np.random.random((81, 4)).astype('float32')
......@@ -180,7 +179,7 @@ class TestBoxCoderOpWithLoD(OpTest):
self.check_output(check_eager=True)
def setUp(self):
self.python_api = paddle.fluid.layers.box_coder
self.python_api = paddle.vision.ops.box_coder
self.op_type = "box_coder"
lod = [[10, 20, 20]]
prior_box = np.random.random((20, 4)).astype('float32')
......@@ -211,7 +210,7 @@ class TestBoxCoderOpWithAxis(OpTest):
self.check_output(check_eager=True)
def setUp(self):
self.python_api = paddle.fluid.layers.box_coder
self.python_api = paddle.vision.ops.box_coder
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((30, 4)).astype('float32')
......@@ -298,13 +297,13 @@ class TestBoxCoderOpWithVarianceDygraphAPI(unittest.TestCase):
self.axis,
)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
output_box = paddle.fluid.layers.box_coder(
output_box = paddle.vision.ops.box_coder(
paddle.to_tensor(self.prior_box),
self.prior_box_var.tolist(),
paddle.to_tensor(self.target_box),
......
......@@ -2294,14 +2294,6 @@ class TestBook(LayerTest):
return values
return indices
def make_polygon_box_transform(self):
with program_guard(
fluid.default_main_program(), fluid.default_startup_program()
):
x = self._get_data(name='x', shape=[8, 4, 4], dtype="float32")
output = layers.polygon_box_transform(input=x)
return output
def make_l2_normalize(self):
with program_guard(
fluid.default_main_program(), fluid.default_startup_program()
......
......@@ -17,8 +17,6 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
def PolygonBoxRestore(input):
shape = input.shape
......@@ -72,16 +70,5 @@ class TestCase2(TestPolygonBoxRestoreOp):
self.input_shape = (3, 12, 4, 5)
class TestPolygonBoxInvalidInput(unittest.TestCase):
def test_error(self):
def test_invalid_input():
input = fluid.data(
name='input', shape=[None, 3, 32, 32], dtype='int64'
)
out = fluid.layers.polygon_box_transform(input)
self.assertRaises(TypeError, test_invalid_input)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册