提交 6e79d01b 编写于 作者: C chengduoZH

merge prior_box and multi_box

上级 8ea2288e
......@@ -23,7 +23,6 @@ import nn
import math
__all__ = [
'prior_box',
'multi_box_head',
'bipartite_match',
'target_assign',
......@@ -133,219 +132,6 @@ def detection_output(scores,
return nmsed_outs
def prior_box(inputs,
image,
min_ratio,
max_ratio,
aspect_ratios,
base_size,
steps=None,
step_w=None,
step_h=None,
offset=0.5,
variance=[0.1, 0.1, 0.1, 0.1],
flip=False,
clip=False,
min_sizes=None,
max_sizes=None,
name=None):
"""
**Prior_boxes**
Generate prior boxes for SSD(Single Shot MultiBox Detector)
algorithm. The details of this algorithm, please refer the
section 2.2 of SSD paper (SSD: Single Shot MultiBox Detector)
<https://arxiv.org/abs/1512.02325>`_ .
Args:
inputs(list|tuple): The list of input Variables, the format
of all Variables is NCHW.
image(Variable): The input image data of PriorBoxOp,
the layout is NCHW.
min_ratio(int): the min ratio of generated prior boxes.
max_ratio(int): the max ratio of generated prior boxes.
aspect_ratios(list|tuple): the aspect ratios of generated prior
boxes. The length of input and aspect_ratios must be equal.
base_size(int): the base_size is used to get min_size
and max_size according to min_ratio and max_ratio.
step_w(list|tuple|None): Prior boxes step
across width. If step_w[i] == 0.0, the prior boxes step
across width of the inputs[i] will be automatically calculated.
step_h(list|tuple|None): Prior boxes step
across height, If step_h[i] == 0.0, the prior boxes
step across height of the inputs[i] will be automatically calculated.
offset(float, optional, default=0.5): Prior boxes center offset.
variance(list|tuple|[0.1, 0.1, 0.1, 0.1]): the variances
to be encoded in prior boxes.
flip(bool|False): Whether to flip
aspect ratios.
clip(bool, optional, default=False): Whether to clip
out-of-boundary boxes.
min_sizes(list|tuple|None): If `len(inputs) <=2`,
min_sizes must be set up, and the length of min_sizes
should equal to the length of inputs.
max_sizes(list|tuple|None): If `len(inputs) <=2`,
max_sizes must be set up, and the length of min_sizes
should equal to the length of inputs.
name(str|None): Name of the prior box layer.
Returns:
boxes(Variable): the output prior boxes of PriorBox.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs.
Variances(Variable): the expanded variances of PriorBox.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs
Examples:
.. code-block:: python
prior_box(
inputs = [conv1, conv2, conv3, conv4, conv5, conv6],
image = data,
min_ratio = 20, # 0.20
max_ratio = 90, # 0.90
offset = 0.5,
base_size = 300,
variance = [0.1,0.1,0.1,0.1],
aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
flip=True,
clip=True)
"""
def _prior_box_(input,
image,
min_sizes,
max_sizes,
aspect_ratios,
variance,
flip=False,
clip=False,
step_w=0.0,
step_h=0.0,
offset=0.5,
name=None):
helper = LayerHelper("prior_box", **locals())
dtype = helper.input_dtype()
box = helper.create_tmp_variable(dtype)
var = helper.create_tmp_variable(dtype)
helper.append_op(
type="prior_box",
inputs={"Input": input,
"Image": image},
outputs={"Boxes": box,
"Variances": var},
attrs={
'min_sizes': min_sizes,
'max_sizes': max_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'flip': flip,
'clip': clip,
'step_w': step_w,
'step_h': step_h,
'offset': offset
})
return box, var
def _reshape_with_axis_(input, axis=1):
if not (axis > 0 and axis < len(input.shape)):
raise ValueError("The axis should be smaller than "
"the arity of input and bigger than 0.")
new_shape = [
-1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
]
out = ops.reshape(x=input, shape=new_shape)
return out
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
def _is_list_or_tuple_and_equal(data, length, err_info):
if not (_is_list_or_tuple_(data) and len(data) == length):
raise ValueError(err_info)
if not _is_list_or_tuple_(inputs):
raise ValueError('inputs should be a list or tuple.')
num_layer = len(inputs)
if num_layer <= 2:
assert min_sizes is not None and max_sizes is not None
assert len(min_sizes) == num_layer and len(max_sizes) == num_layer
else:
min_sizes = []
max_sizes = []
step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2)))
for ratio in xrange(min_ratio, max_ratio + 1, step):
min_sizes.append(base_size * ratio / 100.)
max_sizes.append(base_size * (ratio + step) / 100.)
min_sizes = [base_size * .10] + min_sizes
max_sizes = [base_size * .20] + max_sizes
if aspect_ratios:
_is_list_or_tuple_and_equal(
aspect_ratios, num_layer,
'aspect_ratios should be list or tuple, and the length of inputs '
'and aspect_ratios should be the same.')
if step_h:
_is_list_or_tuple_and_equal(
step_h, num_layer,
'step_h should be list or tuple, and the length of inputs and '
'step_h should be the same.')
if step_w:
_is_list_or_tuple_and_equal(
step_w, num_layer,
'step_w should be list or tuple, and the length of inputs and '
'step_w should be the same.')
if steps:
_is_list_or_tuple_and_equal(
steps, num_layer,
'steps should be list or tuple, and the length of inputs and '
'step_w should be the same.')
step_w = steps
step_h = steps
box_results = []
var_results = []
for i, input in enumerate(inputs):
min_size = min_sizes[i]
max_size = max_sizes[i]
aspect_ratio = []
if not _is_list_or_tuple_(min_size):
min_size = [min_size]
if not _is_list_or_tuple_(max_size):
max_size = [max_size]
if aspect_ratios:
aspect_ratio = aspect_ratios[i]
if not _is_list_or_tuple_(aspect_ratio):
aspect_ratio = [aspect_ratio]
box, var = _prior_box_(input, image, min_size, max_size, aspect_ratio,
variance, flip, clip, step_w[i]
if step_w else 0.0, step_h[i]
if step_w else 0.0, offset)
box_results.append(box)
var_results.append(var)
if len(box_results) == 1:
box = box_results[0]
var = var_results[0]
else:
reshaped_boxes = []
reshaped_vars = []
for i in range(len(box_results)):
reshaped_boxes.append(_reshape_with_axis_(box_results[i], axis=3))
reshaped_vars.append(_reshape_with_axis_(var_results[i], axis=3))
box = tensor.concat(reshaped_boxes)
var = tensor.concat(reshaped_vars)
return box, var
def bipartite_match(dist_matrix, name=None):
"""
**Bipartite matchint operator**
......@@ -672,106 +458,162 @@ def ssd_loss(location,
def multi_box_head(inputs,
image,
base_size,
num_classes,
aspect_ratios,
min_ratio,
max_ratio,
min_sizes=None,
max_sizes=None,
min_ratio=None,
max_ratio=None,
aspect_ratios=None,
steps=None,
step_w=None,
step_h=None,
offset=0.5,
variance=[0.1, 0.1, 0.1, 0.1],
flip=False,
share_location=True,
clip=False,
kernel_size=1,
pad=1,
pad=0,
stride=1,
use_batchnorm=False,
base_size=None):
name=None):
"""
**Multi Box Head**
**Prior_boxes**
Generate prior boxes' location and confidence for SSD(Single
Shot MultiBox Detector)algorithm. The details of this algorithm,
please refer the section 2.1 of SSD paper (SSD: Single Shot
MultiBox Detector)<https://arxiv.org/abs/1512.02325>`_ .
Generate prior boxes for SSD(Single Shot MultiBox Detector)
algorithm. The details of this algorithm, please refer the
section 2.2 of SSD paper (SSD: Single Shot MultiBox Detector)
<https://arxiv.org/abs/1512.02325>`_ .
Args:
inputs(list|tuple): The list of input Variables, the format
of all Variables is NCHW.
num_classes(int): The number of classes.
min_sizes(list|tuple|None): The number of
min_sizes is used to compute the number of predicted box.
If the min_size is None, it will be computed according
to min_ratio and max_ratio.
max_sizes(list|tuple|None): The number of max_sizes
is used to compute the the number of predicted box.
min_ratio(int|None): If the min_sizes is None, min_ratio and max_ratio
will be used to compute the min_sizes and max_sizes.
max_ratio(int|None): If the min_sizes is None, max_ratio and min_ratio
will be used to compute the min_sizes and max_sizes.
aspect_ratios(list|tuple): The number of the aspect ratios is used to
compute the number of prior box.
image(Variable): The input image data of PriorBoxOp,
the layout is NCHW.
base_size(int): the base_size is used to get min_size
and max_size according to min_ratio and max_ratio.
flip(bool|False): Whether to flip
aspect ratios.
name(str|None): Name of the prior box layer.
num_classes(int): The number of classes.
aspect_ratios(list|tuple): the aspect ratios of generated prior
boxes. The length of input and aspect_ratios must be equal.
min_ratio(int): the min ratio of generated prior boxes.
max_ratio(int): the max ratio of generated prior boxes.
min_sizes(list|tuple|None): If `len(inputs) <=2`,
min_sizes must be set up, and the length of min_sizes
should equal to the length of inputs. Default: None.
max_sizes(list|tuple|None): If `len(inputs) <=2`,
max_sizes must be set up, and the length of min_sizes
should equal to the length of inputs. Default: None.
steps(list|tuple): If step_w and step_h are the same,
step_w and step_h can be replaced by steps.
step_w(list|tuple): Prior boxes step
across width. If step_w[i] == 0.0, the prior boxes step
across width of the inputs[i] will be automatically
calculated. Default: None.
step_h(list|tuple): Prior boxes step across height, If
step_h[i] == 0.0, the prior boxes step across height of
the inputs[i] will be automatically calculated. Default: None.
offset(float): Prior boxes center offset. Default: 0.5
variance(list|tuple): the variances to be encoded in prior boxes.
Default:[0.1, 0.1, 0.1, 0.1].
flip(bool): Whether to flip aspect ratios. Default:False.
clip(bool): Whether to clip out-of-boundary boxes. Default: False.
kernel_size(int): The kernel size of conv2d. Default: 1.
pad(int|list|tuple): The padding of conv2d. Default:0.
stride(int|list|tuple): The stride of conv2d. Default:1,
name(str): Name of the prior box layer. Default: None.
Returns:
mbox_loc(list): The predicted boxes' location of the inputs.
The layout of each element is [N, H, W, Priors]. Priors
is the number of predicted boxof each position of each input.
mbox_conf(list): The predicted boxes' confidence of the inputs.
The layout of each element is [N, H, W, Priors]. Priors
is the number of predicted box of each position of each input.
boxes(Variable): the output prior boxes of PriorBox.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs.
Variances(Variable): the expanded variances of PriorBox.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs
Examples:
.. code-block:: python
mbox_locs, mbox_confs = detection.multi_box_head(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
num_classes=21,
min_ratio=20,
max_ratio=90,
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
base_size=300,
flip=True)
mbox_locs, mbox_confs, box, var = layers.multi_box_head(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
image=images,
num_classes=21,
min_ratio=20,
max_ratio=90,
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
base_size=300,
offset=0.5,
flip=True,
clip=True)
"""
def _is_equal_(len1, len2, err_info):
if not (len1 == len2):
raise ValueError(err_info)
def _prior_box_(input,
image,
min_sizes,
max_sizes,
aspect_ratios,
variance,
flip=False,
clip=False,
step_w=0.0,
step_h=0.0,
offset=0.5,
name=None):
helper = LayerHelper("prior_box", **locals())
dtype = helper.input_dtype()
box = helper.create_tmp_variable(dtype)
var = helper.create_tmp_variable(dtype)
helper.append_op(
type="prior_box",
inputs={"Input": input,
"Image": image},
outputs={"Boxes": box,
"Variances": var},
attrs={
'min_sizes': min_sizes,
'max_sizes': max_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'flip': flip,
'clip': clip,
'step_w': step_w,
'step_h': step_h,
'offset': offset
})
return box, var
def _reshape_with_axis_(input, axis=1):
if not (axis > 0 and axis < len(input.shape)):
raise ValueError("The axis should be smaller than "
"the arity of input and bigger than 0.")
new_shape = [
-1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
]
out = ops.reshape(x=input, shape=new_shape)
return out
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
def _is_list_or_tuple_and_equal(data, length, err_info):
if not (_is_list_or_tuple_(data) and len(data) == length):
raise ValueError(err_info)
if not _is_list_or_tuple_(inputs):
raise ValueError('inputs should be a list or tuple.')
if min_sizes is not None:
_is_equal_(
len(inputs),
len(min_sizes), 'the length of min_sizes '
'and inputs should be equal.')
if max_sizes is not None:
_is_equal_(
len(inputs),
len(max_sizes), 'the length of max_sizes '
'and inputs should be equal.')
if aspect_ratios is not None:
_is_equal_(
len(inputs),
len(aspect_ratios), 'the length of aspect_ratios '
'and inputs should be equal.')
if min_sizes is None:
# If min_sizes is None, min_sizes and max_sizes
# will be set according to max_ratio and min_ratio.
num_layer = len(inputs)
assert max_ratio is not None and min_ratio is not None,\
'max_ratio and min_ratio must be not None.'
assert num_layer >= 3, 'The length of the input data is at least three.'
num_layer = len(inputs)
if num_layer <= 2:
assert min_sizes is not None and max_sizes is not None
assert len(min_sizes) == num_layer and len(max_sizes) == num_layer
else:
min_sizes = []
max_sizes = []
step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2)))
......@@ -781,21 +623,43 @@ def multi_box_head(inputs,
min_sizes = [base_size * .10] + min_sizes
max_sizes = [base_size * .20] + max_sizes
if aspect_ratios:
_is_list_or_tuple_and_equal(
aspect_ratios, num_layer,
'aspect_ratios should be list or tuple, and the length of inputs '
'and aspect_ratios should be the same.')
if step_h:
_is_list_or_tuple_and_equal(
step_h, num_layer,
'step_h should be list or tuple, and the length of inputs and '
'step_h should be the same.')
if step_w:
_is_list_or_tuple_and_equal(
step_w, num_layer,
'step_w should be list or tuple, and the length of inputs and '
'step_w should be the same.')
if steps:
_is_list_or_tuple_and_equal(
steps, num_layer,
'steps should be list or tuple, and the length of inputs and '
'step_w should be the same.')
step_w = steps
step_h = steps
mbox_locs = []
mbox_confs = []
box_results = []
var_results = []
for i, input in enumerate(inputs):
min_size = min_sizes[i]
max_size = max_sizes[i]
if not _is_list_or_tuple_(min_size):
min_size = [min_size]
max_size = []
if max_sizes is not None:
max_size = max_sizes[i]
if not _is_list_or_tuple_(max_size):
max_size = [max_size]
_is_equal_(
len(max_size),
len(min_size),
if not _is_list_or_tuple_(max_size):
max_size = [max_size]
if not (len(max_size) == len(min_size)):
raise ValueError(
'the length of max_size and min_size should be equal.')
aspect_ratio = []
......@@ -804,23 +668,18 @@ def multi_box_head(inputs,
if not _is_list_or_tuple_(aspect_ratio):
aspect_ratio = [aspect_ratio]
# get the number of prior box on each location
num_priors_per_location = 0
if max_sizes is not None:
num_priors_per_location = len(min_size) + \
len(aspect_ratio) * len(min_size) +\
len(max_size)
else:
num_priors_per_location = len(min_size) +\
len(aspect_ratio) * len(min_size)
if flip:
num_priors_per_location += len(aspect_ratio) * len(min_size)
# get mbox_loc
num_loc_output = num_priors_per_location * 4
if share_location:
num_loc_output *= num_classes
box, var = _prior_box_(input, image, min_size, max_size, aspect_ratio,
variance, flip, clip, step_w[i]
if step_w else 0.0, step_h[i]
if step_w else 0.0, offset)
box_results.append(box)
var_results.append(var)
num_boxes = box.shape[2]
# get box_loc
num_loc_output = num_boxes * num_classes * 4
mbox_loc = nn.conv2d(
input=input,
num_filters=num_loc_output,
......@@ -832,7 +691,7 @@ def multi_box_head(inputs,
mbox_locs.append(mbox_loc)
# get conf_loc
num_conf_output = num_priors_per_location * num_classes
num_conf_output = num_boxes * num_classes
conf_loc = nn.conv2d(
input=input,
num_filters=num_conf_output,
......@@ -842,4 +701,17 @@ def multi_box_head(inputs,
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
mbox_confs.append(conf_loc)
return mbox_locs, mbox_confs
if len(box_results) == 1:
box = box_results[0]
var = var_results[0]
else:
reshaped_boxes = []
reshaped_vars = []
for i in range(len(box_results)):
reshaped_boxes.append(_reshape_with_axis_(box_results[i], axis=3))
reshaped_vars.append(_reshape_with_axis_(var_results[i], axis=3))
box = tensor.concat(reshaped_boxes)
var = tensor.concat(reshaped_vars)
return mbox_locs, mbox_confs, box, var
......@@ -109,16 +109,19 @@ class TestDetection(unittest.TestCase):
print(str(program))
class TestPriorBox(unittest.TestCase):
def test_prior_box(self):
class TestMultiBoxHead(unittest.TestCase):
def test_multi_box_head(self):
data_shape = [3, 224, 224]
box, var = self.prior_box_output(data_shape)
mbox_locs, mbox_confs, box, var = self.multi_box_head_output(data_shape)
assert len(box.shape) == 2
assert box.shape == var.shape
assert box.shape[1] == 4
def prior_box_output(self, data_shape):
for loc, conf in zip(mbox_locs, mbox_confs):
assert loc.shape[1:3] == conf.shape[1:3]
def multi_box_head_output(self, data_shape):
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(images, 3, 3, 2)
......@@ -127,46 +130,19 @@ class TestPriorBox(unittest.TestCase):
conv4 = fluid.layers.conv2d(conv3, 3, 3, 2)
conv5 = fluid.layers.conv2d(conv4, 3, 3, 2)
box, var = layers.prior_box(
mbox_locs, mbox_confs, box, var = layers.multi_box_head(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
image=images,
num_classes=21,
min_ratio=20,
max_ratio=90,
# steps=[8, 16, 32, 64, 100, 300],
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
base_size=300,
offset=0.5,
flip=True,
clip=True)
return box, var
class TestMultiBoxHead(unittest.TestCase):
def test_prior_box(self):
data_shape = [3, 224, 224]
mbox_locs, mbox_confs = self.multi_box_output(data_shape)
for loc, conf in zip(mbox_locs, mbox_confs):
assert loc.shape[1:3] == conf.shape[1:3]
def multi_box_output(self, data_shape):
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(images, 3, 3, 2)
conv2 = fluid.layers.conv2d(conv1, 3, 3, 2)
conv3 = fluid.layers.conv2d(conv2, 3, 3, 2)
conv4 = fluid.layers.conv2d(conv3, 3, 3, 2)
conv5 = fluid.layers.conv2d(conv4, 3, 3, 2)
mbox_locs, mbox_confs = detection.multi_box_head(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
num_classes=21,
min_ratio=20,
max_ratio=90,
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
base_size=300,
flip=True)
return mbox_locs, mbox_confs
return mbox_locs, mbox_confs, box, var
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册