未验证 提交 86657dbe 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #8382 from chengduoZH/feature/multiBoxHead

Add MultiBox API
......@@ -17,13 +17,13 @@ All layers just related to the detection neural network.
from layer_function_generator import generate_layer_fn
from ..layer_helper import LayerHelper
import nn
import ops
import tensor
import ops
import nn
import math
__all__ = [
'prior_box',
'multi_box_head',
'bipartite_match',
'target_assign',
'detection_output',
......@@ -54,7 +54,7 @@ def detection_output(scores,
"""
**Detection Output Layer**
This layer applies the NMS to the output of network and computes the
This layer applies the NMS to the output of network and computes the
predict bounding box location. The output's shape of this layer could
be zero if there is no valid bounding box.
......@@ -132,211 +132,6 @@ def detection_output(scores,
return nmsed_outs
def prior_box(inputs,
image,
min_ratio,
max_ratio,
aspect_ratios,
base_size,
steps=None,
step_w=None,
step_h=None,
offset=0.5,
variance=[0.1, 0.1, 0.1, 0.1],
flip=False,
clip=False,
min_sizes=None,
max_sizes=None,
name=None):
"""
**Prior_boxes**
Generate prior boxes for SSD(Single Shot MultiBox Detector)
algorithm. The details of this algorithm, please refer the
section 2.2 of SSD paper (SSD: Single Shot MultiBox Detector)
<https://arxiv.org/abs/1512.02325>`_ .
Args:
inputs(list): The list of input Variables, the format
of all Variables is NCHW.
image(Variable): The input image data of PriorBoxOp,
the layout is NCHW.
min_ratio(int): the min ratio of generated prior boxes.
max_ratio(int): the max ratio of generated prior boxes.
aspect_ratios(list): the aspect ratios of generated prior
boxes. The length of input and aspect_ratios must be equal.
base_size(int): the base_size is used to get min_size
and max_size according to min_ratio and max_ratio.
step_w(list, optional, default=None): Prior boxes step
across width. If step_w[i] == 0.0, the prior boxes step
across width of the inputs[i] will be automatically calculated.
step_h(list, optional, default=None): Prior boxes step
across height, If step_h[i] == 0.0, the prior boxes
step across height of the inputs[i] will be automatically calculated.
offset(float, optional, default=0.5): Prior boxes center offset.
variance(list, optional, default=[0.1, 0.1, 0.1, 0.1]): the variances
to be encoded in prior boxes.
flip(bool, optional, default=False): Whether to flip
aspect ratios.
clip(bool, optional, default=False): Whether to clip
out-of-boundary boxes.
min_sizes(list, optional, default=None): If `len(inputs) <=2`,
min_sizes must be set up, and the length of min_sizes
should equal to the length of inputs.
max_sizes(list, optional, default=None): If `len(inputs) <=2`,
max_sizes must be set up, and the length of min_sizes
should equal to the length of inputs.
name(str, optional, None): Name of the prior box layer.
Returns:
boxes(Variable): the output prior boxes of PriorBoxOp.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs.
Variances(Variable): the expanded variances of PriorBoxOp.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs
Examples:
.. code-block:: python
prior_box(
inputs = [conv1, conv2, conv3, conv4, conv5, conv6],
image = data,
min_ratio = 20, # 0.20
max_ratio = 90, # 0.90
offset = 0.5,
base_size = 300,
variance = [0.1,0.1,0.1,0.1],
aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
flip=True,
clip=True)
"""
def _prior_box_(input,
image,
min_sizes,
max_sizes,
aspect_ratios,
variance,
flip=False,
clip=False,
step_w=0.0,
step_h=0.0,
offset=0.5,
name=None):
helper = LayerHelper("prior_box", **locals())
dtype = helper.input_dtype()
box = helper.create_tmp_variable(dtype)
var = helper.create_tmp_variable(dtype)
helper.append_op(
type="prior_box",
inputs={"Input": input,
"Image": image},
outputs={"Boxes": box,
"Variances": var},
attrs={
'min_sizes': min_sizes,
'max_sizes': max_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'flip': flip,
'clip': clip,
'step_w': step_w,
'step_h': step_h,
'offset': offset
})
return box, var
def _reshape_with_axis_(input, axis=1):
if not (axis > 0 and axis < len(input.shape)):
raise ValueError("The axis should be smaller than "
"the arity of input and bigger than 0.")
new_shape = [
-1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
]
out = ops.reshape(x=input, shape=new_shape)
return out
assert isinstance(inputs, list), 'inputs should be a list.'
num_layer = len(inputs)
if num_layer <= 2:
assert min_sizes is not None and max_sizes is not None
assert len(min_sizes) == num_layer and len(max_sizes) == num_layer
else:
min_sizes = []
max_sizes = []
step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2)))
for ratio in xrange(min_ratio, max_ratio + 1, step):
min_sizes.append(base_size * ratio / 100.)
max_sizes.append(base_size * (ratio + step) / 100.)
min_sizes = [base_size * .10] + min_sizes
max_sizes = [base_size * .20] + max_sizes
if aspect_ratios:
if not (isinstance(aspect_ratios, list) and
len(aspect_ratios) == num_layer):
raise ValueError(
'aspect_ratios should be list and the length of inputs '
'and aspect_ratios should be the same.')
if step_h:
if not (isinstance(step_h, list) and len(step_h) == num_layer):
raise ValueError(
'step_h should be list and the length of inputs and '
'step_h should be the same.')
if step_w:
if not (isinstance(step_w, list) and len(step_w) == num_layer):
raise ValueError(
'step_w should be list and the length of inputs and '
'step_w should be the same.')
if steps:
if not (isinstance(steps, list) and len(steps) == num_layer):
raise ValueError(
'steps should be list and the length of inputs and '
'step_w should be the same.')
step_w = steps
step_h = steps
box_results = []
var_results = []
for i, input in enumerate(inputs):
min_size = min_sizes[i]
max_size = max_sizes[i]
aspect_ratio = []
if not isinstance(min_size, list):
min_size = [min_size]
if not isinstance(max_size, list):
max_size = [max_size]
if aspect_ratios:
aspect_ratio = aspect_ratios[i]
if not isinstance(aspect_ratio, list):
aspect_ratio = [aspect_ratio]
box, var = _prior_box_(input, image, min_size, max_size, aspect_ratio,
variance, flip, clip, step_w[i]
if step_w else 0.0, step_h[i]
if step_w else 0.0, offset)
box_results.append(box)
var_results.append(var)
if len(box_results) == 1:
box = box_results[0]
var = var_results[0]
else:
reshaped_boxes = []
reshaped_vars = []
for i in range(len(box_results)):
reshaped_boxes.append(_reshape_with_axis_(box_results[i], axis=3))
reshaped_vars.append(_reshape_with_axis_(var_results[i], axis=3))
box = tensor.concat(reshaped_boxes)
var = tensor.concat(reshaped_vars)
return box, var
def bipartite_match(dist_matrix, name=None):
"""
**Bipartite matchint operator**
......@@ -348,13 +143,13 @@ def bipartite_match(dist_matrix, name=None):
each column. And this operator only calculate matched indices from column
to row. For each instance, the number of matched indices is the number of
of columns of the input ditance matrix.
There are two outputs to save matched indices and distance.
A simple description, this algothrim matched the best (maximum distance)
row entity to the column entity and the matched indices are not duplicated
in each row of ColToRowMatchIndices. If the column entity is not matched
any row entity, set -1 in ColToRowMatchIndices.
Please note that the input DistMat can be LoDTensor (with LoD) or Tensor.
If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
If Tensor, the height of ColToRowMatchIndices is 1.
......@@ -407,30 +202,30 @@ def target_assign(input,
to assign classification and regression targets to each prediction as well as
weights to prediction. The weights is used to specify which prediction would
not contribute to training loss.
For each instance, the output `out` and`out_weight` are assigned based on
`match_indices` and `negative_indices`.
Assumed that the row offset for each instance in `input` is called lod,
this operator assigns classification/regression targets by performing the
following steps:
1. Assigning all outpts based on `match_indices`:
If id = match_indices[i][j] > 0,
out[i][j][0 : K] = X[lod[i] + id][j % P][0 : K]
out_weight[i][j] = 1.
Otherwise,
Otherwise,
out[j][j][0 : K] = {mismatch_value, mismatch_value, ...}
out_weight[i][j] = 0.
2. Assigning out_weight based on `neg_indices` if `neg_indices` is provided:
Assumed that the row offset for each instance in `neg_indices` is called neg_lod,
for i-th instance and each `id` of neg_indices in this instance:
out[i][id][0 : K] = {mismatch_value, mismatch_value, ...}
out_weight[i][id] = 1.0
......@@ -660,3 +455,263 @@ def ssd_loss(location,
# 5.3 Compute overall weighted loss.
loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
return loss
def multi_box_head(inputs,
image,
base_size,
num_classes,
aspect_ratios,
min_ratio,
max_ratio,
min_sizes=None,
max_sizes=None,
steps=None,
step_w=None,
step_h=None,
offset=0.5,
variance=[0.1, 0.1, 0.1, 0.1],
flip=False,
clip=False,
kernel_size=1,
pad=0,
stride=1,
name=None):
"""
**Prior_boxes**
Generate prior boxes for SSD(Single Shot MultiBox Detector)
algorithm. The details of this algorithm, please refer the
section 2.2 of SSD paper (SSD: Single Shot MultiBox Detector)
<https://arxiv.org/abs/1512.02325>`_ .
Args:
inputs(list|tuple): The list of input Variables, the format
of all Variables is NCHW.
image(Variable): The input image data of PriorBoxOp,
the layout is NCHW.
base_size(int): the base_size is used to get min_size
and max_size according to min_ratio and max_ratio.
num_classes(int): The number of classes.
aspect_ratios(list|tuple): the aspect ratios of generated prior
boxes. The length of input and aspect_ratios must be equal.
min_ratio(int): the min ratio of generated prior boxes.
max_ratio(int): the max ratio of generated prior boxes.
min_sizes(list|tuple|None): If `len(inputs) <=2`,
min_sizes must be set up, and the length of min_sizes
should equal to the length of inputs. Default: None.
max_sizes(list|tuple|None): If `len(inputs) <=2`,
max_sizes must be set up, and the length of min_sizes
should equal to the length of inputs. Default: None.
steps(list|tuple): If step_w and step_h are the same,
step_w and step_h can be replaced by steps.
step_w(list|tuple): Prior boxes step
across width. If step_w[i] == 0.0, the prior boxes step
across width of the inputs[i] will be automatically
calculated. Default: None.
step_h(list|tuple): Prior boxes step across height, If
step_h[i] == 0.0, the prior boxes step across height of
the inputs[i] will be automatically calculated. Default: None.
offset(float): Prior boxes center offset. Default: 0.5
variance(list|tuple): the variances to be encoded in prior boxes.
Default:[0.1, 0.1, 0.1, 0.1].
flip(bool): Whether to flip aspect ratios. Default:False.
clip(bool): Whether to clip out-of-boundary boxes. Default: False.
kernel_size(int): The kernel size of conv2d. Default: 1.
pad(int|list|tuple): The padding of conv2d. Default:0.
stride(int|list|tuple): The stride of conv2d. Default:1,
name(str): Name of the prior box layer. Default: None.
Returns:
mbox_loc(list): The predicted boxes' location of the inputs.
The layout of each element is [N, H, W, Priors]. Priors
is the number of predicted boxof each position of each input.
mbox_conf(list): The predicted boxes' confidence of the inputs.
The layout of each element is [N, H, W, Priors]. Priors
is the number of predicted box of each position of each input.
boxes(Variable): the output prior boxes of PriorBox.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs.
Variances(Variable): the expanded variances of PriorBox.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs
Examples:
.. code-block:: python
mbox_locs, mbox_confs, box, var = layers.multi_box_head(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
image=images,
num_classes=21,
min_ratio=20,
max_ratio=90,
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
base_size=300,
offset=0.5,
flip=True,
clip=True)
"""
def _prior_box_(input,
image,
min_sizes,
max_sizes,
aspect_ratios,
variance,
flip=False,
clip=False,
step_w=0.0,
step_h=0.0,
offset=0.5,
name=None):
helper = LayerHelper("prior_box", **locals())
dtype = helper.input_dtype()
box = helper.create_tmp_variable(dtype)
var = helper.create_tmp_variable(dtype)
helper.append_op(
type="prior_box",
inputs={"Input": input,
"Image": image},
outputs={"Boxes": box,
"Variances": var},
attrs={
'min_sizes': min_sizes,
'max_sizes': max_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'flip': flip,
'clip': clip,
'step_w': step_w,
'step_h': step_h,
'offset': offset
})
return box, var
def _reshape_with_axis_(input, axis=1):
if not (axis > 0 and axis < len(input.shape)):
raise ValueError("The axis should be smaller than "
"the arity of input and bigger than 0.")
new_shape = [
-1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
]
out = ops.reshape(x=input, shape=new_shape)
return out
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
def _is_list_or_tuple_and_equal(data, length, err_info):
if not (_is_list_or_tuple_(data) and len(data) == length):
raise ValueError(err_info)
if not _is_list_or_tuple_(inputs):
raise ValueError('inputs should be a list or tuple.')
num_layer = len(inputs)
if num_layer <= 2:
assert min_sizes is not None and max_sizes is not None
assert len(min_sizes) == num_layer and len(max_sizes) == num_layer
else:
min_sizes = []
max_sizes = []
step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2)))
for ratio in xrange(min_ratio, max_ratio + 1, step):
min_sizes.append(base_size * ratio / 100.)
max_sizes.append(base_size * (ratio + step) / 100.)
min_sizes = [base_size * .10] + min_sizes
max_sizes = [base_size * .20] + max_sizes
if aspect_ratios:
_is_list_or_tuple_and_equal(
aspect_ratios, num_layer,
'aspect_ratios should be list or tuple, and the length of inputs '
'and aspect_ratios should be the same.')
if step_h:
_is_list_or_tuple_and_equal(
step_h, num_layer,
'step_h should be list or tuple, and the length of inputs and '
'step_h should be the same.')
if step_w:
_is_list_or_tuple_and_equal(
step_w, num_layer,
'step_w should be list or tuple, and the length of inputs and '
'step_w should be the same.')
if steps:
_is_list_or_tuple_and_equal(
steps, num_layer,
'steps should be list or tuple, and the length of inputs and '
'step_w should be the same.')
step_w = steps
step_h = steps
mbox_locs = []
mbox_confs = []
box_results = []
var_results = []
for i, input in enumerate(inputs):
min_size = min_sizes[i]
max_size = max_sizes[i]
if not _is_list_or_tuple_(min_size):
min_size = [min_size]
if not _is_list_or_tuple_(max_size):
max_size = [max_size]
if not (len(max_size) == len(min_size)):
raise ValueError(
'the length of max_size and min_size should be equal.')
aspect_ratio = []
if aspect_ratios is not None:
aspect_ratio = aspect_ratios[i]
if not _is_list_or_tuple_(aspect_ratio):
aspect_ratio = [aspect_ratio]
box, var = _prior_box_(input, image, min_size, max_size, aspect_ratio,
variance, flip, clip, step_w[i]
if step_w else 0.0, step_h[i]
if step_w else 0.0, offset)
box_results.append(box)
var_results.append(var)
num_boxes = box.shape[2]
# get box_loc
num_loc_output = num_boxes * num_classes * 4
mbox_loc = nn.conv2d(
input=input,
num_filters=num_loc_output,
filter_size=kernel_size,
padding=pad,
stride=stride)
mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
mbox_locs.append(mbox_loc)
# get conf_loc
num_conf_output = num_boxes * num_classes
conf_loc = nn.conv2d(
input=input,
num_filters=num_conf_output,
filter_size=kernel_size,
padding=pad,
stride=stride)
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
mbox_confs.append(conf_loc)
if len(box_results) == 1:
box = box_results[0]
var = var_results[0]
else:
reshaped_boxes = []
reshaped_vars = []
for i in range(len(box_results)):
reshaped_boxes.append(_reshape_with_axis_(box_results[i], axis=3))
reshaped_vars.append(_reshape_with_axis_(var_results[i], axis=3))
box = tensor.concat(reshaped_boxes)
var = tensor.concat(reshaped_vars)
return mbox_locs, mbox_confs, box, var
......@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import print_function
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.framework import Program, program_guard
import unittest
......@@ -108,60 +109,40 @@ class TestDetection(unittest.TestCase):
print(str(program))
class TestPriorBox(unittest.TestCase):
def test_prior_box(self):
class TestMultiBoxHead(unittest.TestCase):
def test_multi_box_head(self):
data_shape = [3, 224, 224]
box, var = self.prior_box_output(data_shape)
mbox_locs, mbox_confs, box, var = self.multi_box_head_output(data_shape)
assert len(box.shape) == 2
assert box.shape == var.shape
assert box.shape[1] == 4
def prior_box_output(self, data_shape):
images = layers.data(name='pixel', shape=data_shape, dtype='float32')
conv1 = layers.conv2d(
input=images,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv2 = layers.conv2d(
input=conv1,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv3 = layers.conv2d(
input=conv2,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv4 = layers.conv2d(
input=conv3,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv5 = layers.conv2d(
input=conv4,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
box, var = layers.prior_box(
for loc, conf in zip(mbox_locs, mbox_confs):
assert loc.shape[1:3] == conf.shape[1:3]
def multi_box_head_output(self, data_shape):
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(images, 3, 3, 2)
conv2 = fluid.layers.conv2d(conv1, 3, 3, 2)
conv3 = fluid.layers.conv2d(conv2, 3, 3, 2)
conv4 = fluid.layers.conv2d(conv3, 3, 3, 2)
conv5 = fluid.layers.conv2d(conv4, 3, 3, 2)
mbox_locs, mbox_confs, box, var = layers.multi_box_head(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
image=images,
num_classes=21,
min_ratio=20,
max_ratio=90,
# steps=[8, 16, 32, 64, 100, 300],
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
base_size=300,
offset=0.5,
flip=True,
clip=True)
return box, var
return mbox_locs, mbox_confs, box, var
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册