Commit 49c50c9f authored by chengduoZH

Add multiBox API

Parent cb4eacb1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -16,8 +16,19 @@ All layers just related to the detection neural network.
"""
from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
from ..framework import Variable
from layer_function_generator import autodoc
from tensor import concat
from ops import reshape
from ..nets import img_conv_with_bn
from nn import transpose
import math
__all__ = [
    'detection_output',
    'multi_box_head',
]
def detection_output(scores,
...@@ -114,3 +125,147 @@ def detection_output(scores,
            'nms_eta': 1.0
        })
    return nmsed_outs
def multi_box_head(inputs,
                   num_classes,
                   min_sizes=None,
                   max_sizes=None,
                   min_ratio=None,
                   max_ratio=None,
                   aspect_ratios=None,
                   flip=False,
                   share_location=True,
                   kernel_size=1,
                   pad=1,
                   stride=1,
                   use_batchnorm=False,
                   base_size=None,
                   name=None):
"""
**Multi Box Head**
input many Variable, and return mbox_loc, mbox_conf
Args:
inputs(list): The list of input Variables, the format
of all Variables is NCHW.
num_classes(int): The number of calss.
min_sizes(list, optional, default=None): The length of
min_size is used to compute the the number of prior box.
If the min_size is None, it will be computed according
to min_ratio and max_ratio.
max_sizes(list, optional, default=None): The length of max_size
is used to compute the the number of prior box.
min_ratio(int): If the min_sizes is None, min_ratio and min_ratio
will be used to compute the min_sizes and max_sizes.
max_ratio(int): If the min_sizes is None, min_ratio and min_ratio
will be used to compute the min_sizes and max_sizes.
aspect_ratios(list): The number of the aspect ratios is used to
compute the number of prior box.
base_size(int): the base_size is used to get min_size
and max_size according to min_ratio and max_ratio.
flip(bool, optional, default=False): Whether to flip
aspect ratios.
name(str, optional, None): Name of the prior box layer.
Returns:
mbox_loc(Variable): the output prior boxes of PriorBoxOp. The layout is
[num_priors, 4]. num_priors is the total box count of each
position of inputs.
mbox_conf(Variable): the expanded variances of PriorBoxOp. The layout
is [num_priors, 4]. num_priors is the total box count of each
position of inputs
Examples:
.. code-block:: python
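            # A minimal sketch mirroring the unit test below. conv1, conv2
            # and conv3 are assumed to be existing NCHW feature-map
            # Variables, and `detection` is paddle.v2.fluid.layers.detection.
            mbox_locs, mbox_confs = detection.multi_box_head(
                inputs=[conv1, conv2, conv3],
                num_classes=21,
                min_ratio=20,
                max_ratio=90,
                aspect_ratios=[[2.], [2., 3.], [2.]],
                base_size=300,
                flip=True)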
"""
    assert isinstance(inputs, list), 'inputs should be a list.'
    if min_sizes is not None:
        assert len(inputs) == len(min_sizes)
    if max_sizes is not None:
        assert len(inputs) == len(max_sizes)
    if min_sizes is None:
        # If min_sizes is None, min_sizes and max_sizes
        # will be set according to max_ratio and min_ratio.
        assert max_ratio is not None and min_ratio is not None
        min_sizes = []
        max_sizes = []
        num_layer = len(inputs)
        step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2)))
        for ratio in xrange(min_ratio, max_ratio + 1, step):
            min_sizes.append(base_size * ratio / 100.)
            max_sizes.append(base_size * (ratio + step) / 100.)
        min_sizes = [base_size * .10] + min_sizes
        max_sizes = [base_size * .20] + max_sizes
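        # For example, with 6 inputs, base_size=300, min_ratio=20 and
        # max_ratio=90 (as in the unit test below), step is 17, giving
        # min_sizes = [30, 60, 111, 162, 213, 264] and
        # max_sizes = [60, 111, 162, 213, 264, 315].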
    if aspect_ratios is not None:
        assert len(inputs) == len(aspect_ratios)
    mbox_locs = []
    mbox_confs = []
    for i, input in enumerate(inputs):
        min_size = min_sizes[i]
        if type(min_size) is not list:
            min_size = [min_size]
        max_size = []
        if max_sizes is not None:
            max_size = max_sizes[i]
            if type(max_size) is not list:
                max_size = [max_size]
            if max_size:
                assert len(max_size) == len(
                    min_size), "max_size and min_size should have same length."
        aspect_ratio = []
        if aspect_ratios is not None:
            aspect_ratio = aspect_ratios[i]
            if type(aspect_ratio) is not list:
                aspect_ratio = [aspect_ratio]
        num_priors_per_location = 0
        if max_sizes is not None:
            num_priors_per_location = len(min_size) + len(aspect_ratio) * len(
                min_size) + len(max_size)
        else:
            num_priors_per_location = len(min_size) + len(aspect_ratio) * len(
                min_size)
        if flip:
            num_priors_per_location += len(aspect_ratio) * len(min_size)
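        # For example, one min_size, one max_size, aspect_ratio=[2., 3.]
        # and flip=True give 1 + 2 + 1 + 2 = 6 prior boxes per location.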
        # mbox_loc: box localization predictions. Localizations are shared
        # across classes when share_location is True; otherwise each class
        # gets its own 4-coordinate prediction.
        num_loc_output = num_priors_per_location * 4
        if not share_location:
            num_loc_output *= num_classes
        mbox_loc = img_conv_with_bn(
            input=input,
            conv_num_filter=num_loc_output,
            conv_padding=pad,
            conv_stride=stride,
            conv_filter_size=kernel_size,
            conv_with_batchnorm=use_batchnorm)
        mbox_loc = transpose(mbox_loc, perm=[0, 2, 3, 1])
        mbox_locs.append(mbox_loc)
        # mbox_conf: per-class confidence predictions.
        num_conf_output = num_priors_per_location * num_classes
        mbox_conf = img_conv_with_bn(
            input=input,
            conv_num_filter=num_conf_output,
            conv_padding=pad,
            conv_stride=stride,
            conv_filter_size=kernel_size,
            conv_with_batchnorm=use_batchnorm)
        mbox_conf = transpose(mbox_conf, perm=[0, 2, 3, 1])
        mbox_confs.append(mbox_conf)
    return mbox_locs, mbox_confs
...@@ -18,6 +18,7 @@ __all__ = [
    "sequence_conv_pool",
    "glu",
    "scaled_dot_product_attention",
    "img_conv_with_bn",
]
...@@ -107,6 +108,38 @@ def img_conv_group(input,
    return pool_out
def img_conv_with_bn(input,
                     conv_num_filter,
                     conv_padding=1,
                     conv_filter_size=3,
                     conv_stride=1,
                     conv_act=None,
                     param_attr=None,
                     conv_with_batchnorm=False,
                     conv_batchnorm_drop_rate=0.0,
                     use_cudnn=True):
    """
    A single image convolution layer, optionally followed by batch
    normalization and dropout. It is the building block used by
    multi_box_head for the localization and confidence branches.
    """
    conv2d = layers.conv2d(
        input=input,
        num_filters=conv_num_filter,
        filter_size=conv_filter_size,
        padding=conv_padding,
        stride=conv_stride,
        param_attr=param_attr,
        act=conv_act,
        use_cudnn=use_cudnn)
    if conv_with_batchnorm:
        conv2d = layers.batch_norm(input=conv2d)
        drop_rate = conv_batchnorm_drop_rate
        if abs(drop_rate) > 1e-5:
            conv2d = layers.dropout(x=conv2d, dropout_prob=drop_rate)
    return conv2d
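# A minimal usage sketch of img_conv_with_bn (`feature_map` is a hypothetical
# NCHW Variable, e.g. the output of an earlier conv layer):
#
#     loc_pred = img_conv_with_bn(
#         input=feature_map,
#         conv_num_filter=24,
#         conv_padding=1,
#         conv_filter_size=3,
#         conv_stride=1,
#         conv_with_batchnorm=True,
#         conv_batchnorm_drop_rate=0.1)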
def sequence_conv_pool(input,
                       num_filters,
                       filter_size,
...
...@@ -13,7 +13,13 @@
# limitations under the License.
from __future__ import print_function
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.layers.detection as detection
from paddle.v2.fluid.framework import Program, program_guard
import unittest
import numpy as np
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.framework import Program, program_guard
...@@ -49,5 +55,60 @@ class TestBook(unittest.TestCase):
        print(str(program))
class TestMultiBoxHead(unittest.TestCase):
    def test_multi_box_head(self):
        data_shape = [3, 224, 224]
        mbox_locs, mbox_confs = self.multi_box_output(data_shape)
        # One localization and one confidence prediction per input feature map.
        self.assertEqual(len(mbox_locs), 6)
        self.assertEqual(len(mbox_confs), 6)
    def multi_box_output(self, data_shape):
        images = fluid.layers.data(
            name='pixel', shape=data_shape, dtype='float32')
        conv1 = fluid.layers.conv2d(
            input=images,
            num_filters=3,
            filter_size=3,
            stride=2,
            use_cudnn=False)
        conv2 = fluid.layers.conv2d(
            input=conv1,
            num_filters=3,
            filter_size=3,
            stride=2,
            use_cudnn=False)
        conv3 = fluid.layers.conv2d(
            input=conv2,
            num_filters=3,
            filter_size=3,
            stride=2,
            use_cudnn=False)
        conv4 = fluid.layers.conv2d(
            input=conv3,
            num_filters=3,
            filter_size=3,
            stride=2,
            use_cudnn=False)
        conv5 = fluid.layers.conv2d(
            input=conv4,
            num_filters=3,
            filter_size=3,
            stride=2,
            use_cudnn=False)
        mbox_locs, mbox_confs = detection.multi_box_head(
            inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
            num_classes=21,
            min_ratio=20,
            max_ratio=90,
            aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
            base_size=300,
            flip=True)
        return mbox_locs, mbox_confs
if __name__ == '__main__':
    unittest.main()