未验证 提交 36fac289 编写于 作者: W wangguanzhong 提交者: GitHub

add box_coder (#1631)

上级 3fc2622d
...@@ -253,7 +253,7 @@ class Proposal(object): ...@@ -253,7 +253,7 @@ class Proposal(object):
bbox_delta_s = fluid.layers.slice( bbox_delta_s = fluid.layers.slice(
bbox_delta_r, axes=[1], starts=[1], ends=[2]) bbox_delta_r, axes=[1], starts=[1], ends=[2])
refined_bbox = fluid.layers.box_coder( refined_bbox = ops.box_coder(
prior_box=rois, prior_box=rois,
prior_box_var=self.proposal_target_generator.bbox_reg_weights[ prior_box_var=self.proposal_target_generator.bbox_reg_weights[
stage], stage],
......
...@@ -31,7 +31,7 @@ __all__ = [ ...@@ -31,7 +31,7 @@ __all__ = [
'anchor_generator', 'anchor_generator',
#'generate_proposals', #'generate_proposals',
'iou_similarity', 'iou_similarity',
#'box_coder', 'box_coder',
'yolo_box', 'yolo_box',
'multiclass_nms', 'multiclass_nms',
'distribute_fpn_proposals', 'distribute_fpn_proposals',
...@@ -1217,3 +1217,154 @@ def matrix_nms(bboxes, ...@@ -1217,3 +1217,154 @@ def matrix_nms(bboxes,
if return_rois_num: if return_rois_num:
return output, rois_num return output, rois_num
return output return output
def box_coder(prior_box,
              prior_box_var,
              target_box,
              code_type="encode_center_size",
              box_normalized=True,
              axis=0,
              name=None):
    r"""
    **Box Coder Layer**

    Encode/Decode the target bounding box with the priorbox information.

    The Encoding schema described below:

    .. math::

        ox = (tx - px) / pw / pxv

        oy = (ty - py) / ph / pyv

        ow = \log(\abs(tw / pw)) / pwv

        oh = \log(\abs(th / ph)) / phv

    The Decoding schema described below:

    .. math::

        ox = (pw * pxv * tx + px) - tw / 2

        oy = (ph * pyv * ty + py) - th / 2

        ow = \exp(pwv * tw) * pw + tw / 2

        oh = \exp(phv * th) * ph + th / 2

    where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates,
    width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote
    the priorbox's (anchor) center coordinates, width and height. `pxv`,
    `pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`,
    `ow`, `oh` denote the encoded/decoded coordinates, width and height.

    NOTE(review): the decoding formulas above are carried over from the
    original documentation (only a stray ``*`` was removed); confirm them
    against the box_coder operator before relying on the exact terms.

    During Box Decoding, two modes for broadcast are supported. Say target
    box has shape [N, M, 4], and the shape of prior box can be [N, 4] or
    [M, 4]. Then prior box will broadcast to target box along the
    assigned axis.

    Args:
        prior_box(Tensor): Box list prior_box is a 2-D Tensor with shape
            [M, 4] holds M boxes and data type is float32 or float64. Each box
            is represented as [xmin, ymin, xmax, ymax], [xmin, ymin] is the
            left top coordinate of the anchor box, if the input is image feature
            map, they are close to the origin of the coordinate system.
            [xmax, ymax] is the right bottom coordinate of the anchor box.
        prior_box_var(List|Tensor): prior_box_var supports two types of
            input. One is Tensor with shape [M, 4] which holds M group and
            data type is float32 or float64. The other is a list consisting
            of 4 elements shared by all boxes. Any other value, including
            None, is rejected with a TypeError.
        target_box(Tensor): This input can be a 2-D LoDTensor with shape
            [N, 4] when code_type is 'encode_center_size'. This input also can
            be a 3-D Tensor with shape [N, M, 4] when code_type is
            'decode_center_size'. Each box is represented as
            [xmin, ymin, xmax, ymax]. The data type is float32 or float64.
        code_type(str): The code type used with the target box. It can be
            `encode_center_size` or `decode_center_size`. `encode_center_size`
            by default.
        box_normalized(bool): Whether treat the priorbox as a normalized box.
            Set true by default.
        axis(int): Which axis in PriorBox to broadcast for box decode,
            for example, if axis is 0 and TargetBox has shape [N, M, 4] and
            PriorBox has shape [M, 4], then PriorBox will broadcast to [N, M, 4]
            for decoding. It is only valid when code type is
            `decode_center_size`. Set 0 by default.
        name(str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name is no need to set and
            None by default.

    Returns:
        Tensor:
        output_box(Tensor): When code_type is 'encode_center_size', the
        output tensor of box_coder_op with shape [N, M, 4] representing the
        result of N target boxes encoded with M Prior boxes and variances.
        When code_type is 'decode_center_size', N represents the batch size
        and M represents the number of decoded boxes.

    Raises:
        TypeError: If prior_box_var is neither a Variable nor a list.

    Examples:

        .. code-block:: python

            import paddle
            from ppdet.modeling import ops
            paddle.enable_static()
            # For encode
            prior_box_encode = paddle.static.data(name='prior_box_encode',
                                  shape=[512, 4],
                                  dtype='float32')
            target_box_encode = paddle.static.data(name='target_box_encode',
                                   shape=[81, 4],
                                   dtype='float32')
            output_encode = ops.box_coder(prior_box=prior_box_encode,
                                    prior_box_var=[0.1,0.1,0.2,0.2],
                                    target_box=target_box_encode,
                                    code_type="encode_center_size")
            # For decode
            prior_box_decode = paddle.static.data(name='prior_box_decode',
                                  shape=[512, 4],
                                  dtype='float32')
            target_box_decode = paddle.static.data(name='target_box_decode',
                                   shape=[512, 81, 4],
                                   dtype='float32')
            output_decode = ops.box_coder(prior_box=prior_box_decode,
                                    prior_box_var=[0.1,0.1,0.2,0.2],
                                    target_box=target_box_decode,
                                    code_type="decode_center_size",
                                    box_normalized=False,
                                    axis=1)
    """
    check_variable_and_dtype(prior_box, 'prior_box', ['float32', 'float64'],
                             'box_coder')
    check_variable_and_dtype(target_box, 'target_box', ['float32', 'float64'],
                             'box_coder')

    if in_dygraph_mode():
        if isinstance(prior_box_var, Variable):
            # Per-box variance is supplied as the second tensor input.
            output_box = core.ops.box_coder(
                prior_box, prior_box_var, target_box, "code_type", code_type,
                "box_normalized", box_normalized, "axis", axis)
        elif isinstance(prior_box_var, list):
            # Tensor inputs are positional, in the same order as the branch
            # above. With a list variance there is no variance tensor, so its
            # slot is filled with None to keep target_box in the third
            # position; the variance itself travels as the "variance"
            # attribute.
            output_box = core.ops.box_coder(
                prior_box, None, target_box, "code_type", code_type,
                "box_normalized", box_normalized, "axis", axis, "variance",
                prior_box_var)
        else:
            raise TypeError(
                "Input variance of box_coder must be Variable or list")
        return output_box

    # locals() is forwarded to LayerHelper, so no extra local names may be
    # introduced before this call.
    helper = LayerHelper("box_coder", **locals())

    output_box = helper.create_variable_for_type_inference(
        dtype=prior_box.dtype)

    inputs = {"PriorBox": prior_box, "TargetBox": target_box}
    attrs = {
        "code_type": code_type,
        "box_normalized": box_normalized,
        "axis": axis
    }
    # A tensor variance becomes an op input; a list variance becomes an op
    # attribute shared by all boxes.
    if isinstance(prior_box_var, Variable):
        inputs['PriorBoxVar'] = prior_box_var
    elif isinstance(prior_box_var, list):
        attrs['variance'] = prior_box_var
    else:
        raise TypeError("Input variance of box_coder must be Variable or list")
    helper.append_op(
        type="box_coder",
        inputs=inputs,
        attrs=attrs,
        outputs={"OutputBox": output_box})
    return output_box
...@@ -44,6 +44,7 @@ class LayerTest(unittest.TestCase): ...@@ -44,6 +44,7 @@ class LayerTest(unittest.TestCase):
@contextlib.contextmanager @contextlib.contextmanager
def static_graph(self): def static_graph(self):
paddle.enable_static()
scope = fluid.core.Scope() scope = fluid.core.Scope()
program = Program() program = Program()
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
...@@ -66,6 +67,7 @@ class LayerTest(unittest.TestCase): ...@@ -66,6 +67,7 @@ class LayerTest(unittest.TestCase):
@contextlib.contextmanager @contextlib.contextmanager
def dynamic_graph(self, force_to_use_cpu=False): def dynamic_graph(self, force_to_use_cpu=False):
paddle.disable_static()
with fluid.dygraph.guard( with fluid.dygraph.guard(
self._get_place(force_to_use_cpu=force_to_use_cpu)): self._get_place(force_to_use_cpu=force_to_use_cpu)):
paddle.seed(self.seed) paddle.seed(self.seed)
......
...@@ -64,7 +64,6 @@ class TestCollectFpnProposals(LayerTest): ...@@ -64,7 +64,6 @@ class TestCollectFpnProposals(LayerTest):
multi_scores_np.append(scores_np) multi_scores_np.append(scores_np)
rois_num_per_level_np.append(rois_num) rois_num_per_level_np.append(rois_num)
paddle.enable_static()
with self.static_graph(): with self.static_graph():
multi_bboxes = [] multi_bboxes = []
multi_scores = [] multi_scores = []
...@@ -104,7 +103,6 @@ class TestCollectFpnProposals(LayerTest): ...@@ -104,7 +103,6 @@ class TestCollectFpnProposals(LayerTest):
fpn_rois_stat = np.array(fpn_rois_stat) fpn_rois_stat = np.array(fpn_rois_stat)
rois_num_stat = np.array(rois_num_stat) rois_num_stat = np.array(rois_num_stat)
paddle.disable_static()
with self.dynamic_graph(): with self.dynamic_graph():
multi_bboxes_dy = [] multi_bboxes_dy = []
multi_scores_dy = [] multi_scores_dy = []
...@@ -148,9 +146,7 @@ class TestCollectFpnProposals(LayerTest): ...@@ -148,9 +146,7 @@ class TestCollectFpnProposals(LayerTest):
multi_scores.append(scores) multi_scores.append(scores)
return multi_bboxes, multi_scores return multi_bboxes, multi_scores
paddle.enable_static() with self.static_graph():
program = Program()
with program_guard(program):
bbox1 = paddle.static.data( bbox1 = paddle.static.data(
name='rois', shape=[5, 10, 4], dtype='float32', lod_level=1) name='rois', shape=[5, 10, 4], dtype='float32', lod_level=1)
score1 = paddle.static.data( score1 = paddle.static.data(
...@@ -223,8 +219,7 @@ class TestDistributeFpnProposals(LayerTest): ...@@ -223,8 +219,7 @@ class TestDistributeFpnProposals(LayerTest):
self.assertTrue(np.array_equal(res_stat, res_dy)) self.assertTrue(np.array_equal(res_stat, res_dy))
def test_distribute_fpn_proposals_error(self): def test_distribute_fpn_proposals_error(self):
program = Program() with self.static_graph():
with program_guard(program):
fpn_rois = paddle.static.data( fpn_rois = paddle.static.data(
name='data_error', shape=[10, 4], dtype='int32', lod_level=1) name='data_error', shape=[10, 4], dtype='int32', lod_level=1)
self.assertRaises( self.assertRaises(
...@@ -282,8 +277,7 @@ class TestROIAlign(LayerTest): ...@@ -282,8 +277,7 @@ class TestROIAlign(LayerTest):
self.assertTrue(np.array_equal(output_np, output_dy_np)) self.assertTrue(np.array_equal(output_np, output_dy_np))
def test_roi_align_error(self): def test_roi_align_error(self):
program = Program() with self.static_graph():
with program_guard(program):
inputs = paddle.static.data( inputs = paddle.static.data(
name='inputs', shape=[2, 12, 20, 20], dtype='float32') name='inputs', shape=[2, 12, 20, 20], dtype='float32')
rois = paddle.static.data( rois = paddle.static.data(
...@@ -341,8 +335,7 @@ class TestROIPool(LayerTest): ...@@ -341,8 +335,7 @@ class TestROIPool(LayerTest):
self.assertTrue(np.array_equal(output_np, output_dy_np)) self.assertTrue(np.array_equal(output_np, output_dy_np))
def test_roi_pool_error(self): def test_roi_pool_error(self):
program = Program() with self.static_graph():
with program_guard(program):
inputs = paddle.static.data( inputs = paddle.static.data(
name='inputs', shape=[2, 12, 20, 20], dtype='float32') name='inputs', shape=[2, 12, 20, 20], dtype='float32')
rois = paddle.static.data( rois = paddle.static.data(
...@@ -383,7 +376,7 @@ class TestIoUSimilarity(LayerTest): ...@@ -383,7 +376,7 @@ class TestIoUSimilarity(LayerTest):
self.assertTrue(np.array_equal(iou_np, iou_dy_np)) self.assertTrue(np.array_equal(iou_np, iou_dy_np))
class TestYOLOBox(LayerTest): class TestYoloBox(LayerTest):
def test_yolo_box(self): def test_yolo_box(self):
# x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2 # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2
...@@ -438,9 +431,7 @@ class TestYOLOBox(LayerTest): ...@@ -438,9 +431,7 @@ class TestYOLOBox(LayerTest):
self.assertTrue(np.array_equal(scores_np, scores_dy_np)) self.assertTrue(np.array_equal(scores_np, scores_dy_np))
def test_yolo_box_error(self): def test_yolo_box_error(self):
paddle.enable_static() with self.static_graph():
program = Program()
with program_guard(program):
# x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2 # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2
x = paddle.static.data( x = paddle.static.data(
name='x', shape=[1, 30, 7, 7], dtype='float32') name='x', shape=[1, 30, 7, 7], dtype='float32')
...@@ -521,7 +512,6 @@ class TestAnchorGenerator(LayerTest): ...@@ -521,7 +512,6 @@ class TestAnchorGenerator(LayerTest):
def test_anchor_generator(self): def test_anchor_generator(self):
b, c, h, w = 2, 48, 16, 16 b, c, h, w = 2, 48, 16, 16
input_np = np.random.rand(2, 48, 16, 16).astype('float32') input_np = np.random.rand(2, 48, 16, 16).astype('float32')
paddle.enable_static()
with self.static_graph(): with self.static_graph():
input = paddle.static.data( input = paddle.static.data(
name='input', shape=[b, c, h, w], dtype='float32') name='input', shape=[b, c, h, w], dtype='float32')
...@@ -712,5 +702,67 @@ class TestMatrixNMS(LayerTest): ...@@ -712,5 +702,67 @@ class TestMatrixNMS(LayerTest):
return_index=True) return_index=True)
class TestBoxCoder(LayerTest):
    """Exercises ops.box_coder in both static-graph and dynamic-graph mode."""

    def test_box_coder(self):
        """Decode the same random boxes in both modes and compare results."""
        prior_np = np.random.random((81, 4)).astype('float32')
        variance_np = np.random.random((81, 4)).astype('float32')
        target_np = np.random.random((20, 81, 4)).astype('float32')

        # Options shared by the static and dygraph invocations.
        decode_opts = {
            'code_type': "decode_center_size",
            'box_normalized': False,
        }

        # static
        with self.static_graph():
            prior_box = paddle.static.data(
                name='prior_box', shape=[81, 4], dtype='float32')
            prior_box_var = paddle.static.data(
                name='prior_box_var', shape=[81, 4], dtype='float32')
            target_box = paddle.static.data(
                name='target_box', shape=[20, 81, 4], dtype='float32')

            static_boxes = ops.box_coder(
                prior_box=prior_box,
                prior_box_var=prior_box_var,
                target_box=target_box,
                **decode_opts)

            feed = {
                'prior_box': prior_np,
                'prior_box_var': variance_np,
                'target_box': target_np,
            }
            # Exactly one fetch target, so exactly one result is expected.
            (static_result, ) = self.get_static_graph_result(
                feed=feed, fetch_list=[static_boxes], with_lod=False)

        # dygraph
        with self.dynamic_graph():
            dy_boxes = ops.box_coder(
                prior_box=base.to_variable(prior_np),
                prior_box_var=base.to_variable(variance_np),
                target_box=base.to_variable(target_np),
                **decode_opts)
            dy_result = dy_boxes.numpy()

        self.assertTrue(np.array_equal(static_result, dy_result))

    def test_box_coder_error(self):
        """An int32 prior_box must be rejected with TypeError."""
        with self.static_graph():
            prior_box = paddle.static.data(
                name='prior_box', shape=[81, 4], dtype='int32')
            prior_box_var = paddle.static.data(
                name='prior_box_var', shape=[81, 4], dtype='float32')
            target_box = paddle.static.data(
                name='target_box', shape=[20, 81, 4], dtype='float32')

            with self.assertRaises(TypeError):
                ops.box_coder(prior_box, prior_box_var, target_box)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册