def _as_list(value):
    """Return `value` unchanged if it is a list, otherwise wrap it in one."""
    return value if isinstance(value, list) else [value]


def _sizes_from_ratios(num_layer, min_ratio, max_ratio, base_size):
    """Derive per-layer prior box sizes from a percentage ratio range.

    The first layer always gets the fixed sizes 10% / 20% of `base_size`;
    the remaining layers get sizes spaced evenly between `min_ratio` and
    `max_ratio` (both in percent of `base_size`).

    Returns:
        tuple: (min_sizes, max_sizes), two lists of equal length holding at
        least `num_layer` entries each. Surplus entries are simply never
        indexed by the caller.

    Raises:
        ValueError: if `base_size` is missing, or if the ratio range cannot
            be split into a positive integer step for `num_layer` layers.
    """
    if base_size is None:
        raise ValueError(
            'base_size is required when min_sizes is not given.')

    min_sizes = [base_size * .10]
    max_sizes = [base_size * .20]
    if num_layer < 2:
        # Only the fixed first-layer sizes are needed.
        return min_sizes, max_sizes
    if num_layer == 2:
        # The step below divides by (num_layer - 2); fail with a clear
        # message instead of a ZeroDivisionError.
        raise ValueError(
            'Cannot derive sizes from min_ratio/max_ratio for exactly two '
            'inputs; pass min_sizes (and max_sizes) explicitly.')

    # Floor division matches int(math.floor(a / b)) on both Python 2 and 3.
    step = int((max_ratio - min_ratio) // (num_layer - 2))
    if step <= 0:
        raise ValueError(
            'max_ratio - min_ratio (%s) is too small to spread over %d '
            'inputs.' % (max_ratio - min_ratio, num_layer))

    for ratio in range(min_ratio, max_ratio + 1, step):
        min_sizes.append(base_size * ratio / 100.)
        max_sizes.append(base_size * (ratio + step) / 100.)
    return min_sizes, max_sizes


def multi_box_head(inputs,
                   num_classes,
                   min_sizes=None,
                   max_sizes=None,
                   min_ratio=None,
                   max_ratio=None,
                   aspect_ratios=None,
                   flip=False,
                   share_location=True,
                   kernel_size=1,
                   pad=1,
                   stride=1,
                   use_batchnorm=False,
                   base_size=None,
                   name=None):
    """
    **Multi Box Head**

    Build the location and confidence prediction convolutions of an
    SSD-style detector, one pair per input feature map.

    Args:
        inputs(list): The list of input Variables, expected in NCHW format.
        num_classes(int): The number of classes.
        min_sizes(list, optional, default=None): One entry (scalar or list)
            per input. Its length is used to compute the number of prior
            boxes. If None, min_sizes and max_sizes are both computed from
            min_ratio, max_ratio and base_size, and any max_sizes passed in
            is ignored.
        max_sizes(list, optional, default=None): One entry (scalar or list)
            per input. Its length is used to compute the number of prior
            boxes.
        min_ratio(int): Lower bound, in percent of base_size, used to derive
            the sizes when min_sizes is None.
        max_ratio(int): Upper bound, in percent of base_size, used to derive
            the sizes when min_sizes is None.
        aspect_ratios(list): One entry (scalar or list) per input. The number
            of aspect ratios is used to compute the number of prior boxes.
        flip(bool, optional, default=False): Whether aspect ratios are
            flipped, which adds one more prior box per aspect ratio.
        share_location(bool, optional, default=True): If True, a single set
            of 4 box offsets per prior box is predicted for all classes. If
            False, a separate set is predicted for every class.
        kernel_size(int): Filter size of the prediction convolutions.
        pad(int): Padding of the prediction convolutions.
        stride(int): Stride of the prediction convolutions.
        use_batchnorm(bool): Whether batch normalisation follows each
            prediction convolution.
        base_size(int): Reference size used with min_ratio and max_ratio to
            derive min_sizes and max_sizes.
        name(str, optional, None): Currently unused.

    Returns:

        mbox_locs(list): One Variable per input holding the location
            predictions, transposed with perm [0, 2, 3, 1].
        mbox_confs(list): One Variable per input holding the confidence
            predictions, transposed with perm [0, 2, 3, 1].

    Raises:
        ValueError: if sizes must be derived from ratios but base_size is
            missing or the ratio range does not fit the number of inputs.
    """

    assert isinstance(inputs, list), 'inputs should be a list.'

    if min_sizes is not None:
        assert len(inputs) == len(min_sizes), \
            'min_sizes needs one entry per input.'

    if max_sizes is not None:
        assert len(inputs) == len(max_sizes), \
            'max_sizes needs one entry per input.'

    if min_sizes is None:
        # Without explicit sizes, both size lists are derived from the
        # ratio range.
        assert max_ratio is not None and min_ratio is not None, \
            'min_ratio and max_ratio are required when min_sizes is None.'
        min_sizes, max_sizes = _sizes_from_ratios(
            len(inputs), min_ratio, max_ratio, base_size)

    if aspect_ratios is not None:
        assert len(inputs) == len(aspect_ratios), \
            'aspect_ratios needs one entry per input.'

    mbox_locs = []
    mbox_confs = []
    for i, feature_map in enumerate(inputs):
        min_size = _as_list(min_sizes[i])

        max_size = []
        if max_sizes is not None:
            max_size = _as_list(max_sizes[i])
            if max_size:
                assert len(max_size) == len(
                    min_size), "max_size and min_size should have same length."

        aspect_ratio = []
        if aspect_ratios is not None:
            aspect_ratio = _as_list(aspect_ratios[i])

        # One box per min size, one per max size, and one per
        # (aspect ratio, min size) pair; flipping doubles the latter.
        ratio_boxes = len(aspect_ratio) * len(min_size)
        num_priors_per_location = len(min_size) + ratio_boxes + len(max_size)
        if flip:
            num_priors_per_location += ratio_boxes

        # Location branch: 4 offsets per prior box. Per-class offsets are
        # only needed when locations are NOT shared between classes.
        num_loc_output = num_priors_per_location * 4
        if not share_location:
            num_loc_output *= num_classes

        mbox_loc = img_conv_with_bn(
            input=feature_map,
            conv_num_filter=num_loc_output,
            conv_padding=pad,
            conv_stride=stride,
            conv_filter_size=kernel_size,
            conv_with_batchnorm=use_batchnorm)
        mbox_locs.append(transpose(mbox_loc, perm=[0, 2, 3, 1]))

        # Confidence branch: one score per class for every prior box.
        num_conf_output = num_priors_per_location * num_classes
        mbox_conf = img_conv_with_bn(
            input=feature_map,
            conv_num_filter=num_conf_output,
            conv_padding=pad,
            conv_stride=stride,
            conv_filter_size=kernel_size,
            conv_with_batchnorm=use_batchnorm)
        mbox_confs.append(transpose(mbox_conf, perm=[0, 2, 3, 1]))

    return mbox_locs, mbox_confs
def img_conv_with_bn(input,
                     conv_num_filter,
                     conv_padding=1,
                     conv_filter_size=3,
                     conv_stride=1,
                     conv_act=None,
                     param_attr=None,
                     conv_with_batchnorm=False,
                     conv_batchnorm_drop_rate=0.0,
                     use_cudnn=True):
    """
    A single 2-D convolution, optionally followed by batch normalisation
    and dropout.

    Args:
        input(Variable): Input of the convolution.
        conv_num_filter(int): Number of convolution filters.
        conv_padding(int): Padding of the convolution.
        conv_filter_size(int): Filter size of the convolution.
        conv_stride(int): Stride of the convolution.
        conv_act(str, optional): Activation. It is applied by the
            convolution when batch normalisation is off, and by the batch
            normalisation layer when it is on.
        param_attr: Parameter attribute forwarded to the convolution.
        conv_with_batchnorm(bool): Whether to append batch normalisation.
        conv_batchnorm_drop_rate(float): Dropout probability applied after
            batch normalisation. Only used when conv_with_batchnorm is True
            and the rate's magnitude exceeds 1e-5.
        use_cudnn(bool): Forwarded to the convolution.

    Returns:
        Variable: Output of the last layer that was applied.
    """
    # With batch norm the convolution stays linear so that the activation
    # acts on the normalised output rather than being normalised away.
    conv_out = layers.conv2d(
        input=input,
        num_filters=conv_num_filter,
        filter_size=conv_filter_size,
        padding=conv_padding,
        stride=conv_stride,
        param_attr=param_attr,
        act=None if conv_with_batchnorm else conv_act,
        use_cudnn=use_cudnn)

    if not conv_with_batchnorm:
        return conv_out

    bn_out = layers.batch_norm(input=conv_out, act=conv_act)
    # A rate this close to zero means "no dropout"; skip the extra op.
    if abs(conv_batchnorm_drop_rate) > 1e-5:
        bn_out = layers.dropout(
            x=bn_out, dropout_prob=conv_batchnorm_drop_rate)
    return bn_out
class TestMultiBoxHead(unittest.TestCase):
    """Checks the structure of the outputs built by detection.multi_box_head."""

    NUM_CLASSES = 21
    # One aspect-ratio list per feature map handed to multi_box_head.
    ASPECT_RATIOS = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]]

    def test_prior_box(self):
        data_shape = [3, 224, 224]
        mbox_locs, mbox_confs = self.multi_box_output(data_shape)

        # One location and one confidence output per input feature map.
        self.assertEqual(len(mbox_locs), len(self.ASPECT_RATIOS))
        self.assertEqual(len(mbox_confs), len(self.ASPECT_RATIOS))

        for loc, conf, ratios in zip(mbox_locs, mbox_confs,
                                     self.ASPECT_RATIOS):
            self.assertEqual(len(loc.shape), 4)
            self.assertEqual(len(conf.shape), 4)
            # With one min size, one max size and flipped aspect ratios,
            # each location carries 2 + 2 * len(ratios) prior boxes. The
            # head transposes its output so channels come last.
            num_priors = 2 + 2 * len(ratios)
            self.assertEqual(conf.shape[-1], num_priors * self.NUM_CLASSES)

    def multi_box_output(self, data_shape):
        """Build five chained stride-2 convolutions and attach the head.

        The last feature map is passed twice so that six inputs are used.
        """
        images = fluid.layers.data(
            name='pixel', shape=data_shape, dtype='float32')

        feature_maps = []
        layer_input = images
        for _ in range(5):
            layer_input = fluid.layers.conv2d(
                input=layer_input,
                num_filters=3,
                filter_size=3,
                stride=2,
                use_cudnn=False)
            feature_maps.append(layer_input)

        mbox_locs, mbox_confs = detection.multi_box_head(
            inputs=feature_maps + [feature_maps[-1]],
            num_classes=self.NUM_CLASSES,
            min_ratio=20,
            max_ratio=90,
            aspect_ratios=self.ASPECT_RATIOS,
            base_size=300,
            flip=True)
        return mbox_locs, mbox_confs