提交 e9fa7a7b 编写于 作者: C chengduoZH

follow comments of qingqing and code refine

上级 99c9dbf5
......@@ -18,10 +18,9 @@ All layers just related to the detection neural network.
from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
from ..framework import Variable
from ..nets import img_conv_with_bn
from tensor import concat
from ops import reshape
from nn import transpose
import tensor
import ops
import nn
import math
__all__ = [
......@@ -184,10 +183,10 @@ def prior_box(inputs,
name(str, optional, None): Name of the prior box layer.
Returns:
boxes(Variable): the output prior boxes of PriorBoxOp.
boxes(Variable): the output prior boxes of PriorBox.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs.
Variances(Variable): the expanded variances of PriorBoxOp.
Variances(Variable): the expanded variances of PriorBox.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs
......@@ -250,7 +249,7 @@ def prior_box(inputs,
new_shape = [
-1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
]
out = reshape(x=input, shape=new_shape)
out = ops.reshape(x=input, shape=new_shape)
return out
assert isinstance(inputs, list), 'inputs should be a list.'
......@@ -326,8 +325,8 @@ def prior_box(inputs,
reshaped_boxes.append(_reshape_with_axis_(box_results[i], axis=3))
reshaped_vars.append(_reshape_with_axis_(var_results[i], axis=3))
box = concat(reshaped_boxes)
var = concat(reshaped_vars)
box = tensor.concat(reshaped_boxes)
var = tensor.concat(reshaped_vars)
return box, var
......@@ -345,12 +344,14 @@ def multi_box_head(inputs,
pad=1,
stride=1,
use_batchnorm=False,
base_size=None,
name=None):
base_size=None):
"""
**Multi Box Head**
input many Variable, and return mbox_loc, mbox_conf
Generate prior boxes' location and confidence for the SSD (Single
Shot MultiBox Detector) algorithm. For the details of this algorithm,
please refer to section 2.1 of the SSD paper `SSD: Single Shot
MultiBox Detector <https://arxiv.org/abs/1512.02325>`_ .
Args:
inputs(list): The list of input Variables, the format
......@@ -376,12 +377,12 @@ def multi_box_head(inputs,
Returns:
mbox_loc(list): the output prior boxes of PriorBoxOp. The layout is
[num_priors, 4]. num_priors is the total box count of each
position of inputs.
mbox_conf(list): the expanded variances of PriorBoxOp. The layout
is [num_priors, 4]. num_priors is the total box count of each
position of inputs
mbox_loc(list): The predicted boxes' location of the inputs.
The layout of each element is [N, H, W, Priors]. Priors
is the number of predicted box of each position of each input.
mbox_conf(list): The predicted boxes' confidence of the inputs.
The layout of each element is [N, H, W, Priors]. Priors
is the number of predicted box of each position of each input.
Examples:
.. code-block:: python
......@@ -396,6 +397,35 @@ def multi_box_head(inputs,
flip=True)
"""
# Convolution helper used below to produce mbox_loc and conf_loc: applies
# nn.conv2d to `input`, then optionally nn.batch_norm and nn.dropout, and
# returns the resulting value.
#
# Parameters (each conv_* argument is forwarded to nn.conv2d under the
# keyword shown):
#   input                    -> input
#   conv_num_filter          -> num_filters
#   conv_padding             -> padding      (default 1)
#   conv_filter_size         -> filter_size  (default 3)
#   conv_stride              -> stride       (default 1)
#   conv_act                 -> act          (default None)
#   param_attr               -> param_attr   (default None)
#   use_cudnn                -> use_cudnn    (default True)
#   conv_with_batchnorm      : when truthy, the conv output is passed
#                              through nn.batch_norm.
#   conv_batchnorm_drop_rate : dropout probability handed to nn.dropout;
#                              values within 1e-5 of zero disable dropout.
#
# NOTE(review): indentation is stripped in this view, so it cannot be told
# from here whether the dropout check is nested under the
# `if conv_with_batchnorm:` branch or runs unconditionally -- confirm
# against the actual file before relying on either reading.
def _conv_with_bn_(input,
conv_num_filter,
conv_padding=1,
conv_filter_size=3,
conv_stride=1,
conv_act=None,
param_attr=None,
conv_with_batchnorm=False,
conv_batchnorm_drop_rate=0.0,
use_cudnn=True):
# Base convolution; every later step rebinds the same `conv2d` name.
conv2d = nn.conv2d(
input=input,
num_filters=conv_num_filter,
filter_size=conv_filter_size,
padding=conv_padding,
stride=conv_stride,
param_attr=param_attr,
act=conv_act,
use_cudnn=use_cudnn)
# Optional batch normalization on top of the convolution output.
if conv_with_batchnorm:
conv2d = nn.batch_norm(input=conv2d)
drop_rate = conv_batchnorm_drop_rate
# Compare against a small epsilon rather than testing `== 0.0`, so a
# rate that is only nominally non-zero does not add a dropout layer.
if abs(drop_rate) > 1e-5:
conv2d = nn.dropout(x=conv2d, dropout_prob=drop_rate)
return conv2d
if not (isinstance(inputs, list)):
raise ValueError('inputs should be a list.')
......@@ -469,26 +499,26 @@ def multi_box_head(inputs,
if share_location:
num_loc_output *= num_classes
mbox_loc = img_conv_with_bn(
mbox_loc = _conv_with_bn_(
input=input,
conv_num_filter=num_loc_output,
conv_padding=pad,
conv_stride=stride,
conv_filter_size=kernel_size,
conv_with_batchnorm=use_batchnorm)
mbox_loc = transpose(mbox_loc, perm=[0, 2, 3, 1])
mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
mbox_locs.append(mbox_loc)
# get conf_loc
num_conf_output = num_priors_per_location * num_classes
conf_loc = img_conv_with_bn(
conf_loc = _conv_with_bn_(
input=input,
conv_num_filter=num_conf_output,
conv_padding=pad,
conv_stride=stride,
conv_filter_size=kernel_size,
conv_with_batchnorm=use_batchnorm)
conf_loc = transpose(conf_loc, perm=[0, 2, 3, 1])
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
mbox_confs.append(conf_loc)
return mbox_locs, mbox_confs
......@@ -47,7 +47,7 @@ class TestBook(unittest.TestCase):
out = layers.detection_output(
scores=scores, loc=loc, prior_box=pb, prior_box_var=pbv)
self.assertIsNotNone(out)
print(str(program))
# print(str(program))
class TestPriorBox(unittest.TestCase):
......@@ -62,36 +62,11 @@ class TestPriorBox(unittest.TestCase):
def prior_box_output(self, data_shape):
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(
input=images,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv2 = fluid.layers.conv2d(
input=conv1,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv3 = fluid.layers.conv2d(
input=conv2,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv4 = fluid.layers.conv2d(
input=conv3,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv5 = fluid.layers.conv2d(
input=conv4,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv1 = fluid.layers.conv2d(images, 3, 3, 2)
conv2 = fluid.layers.conv2d(conv1, 3, 3, 2)
conv3 = fluid.layers.conv2d(conv2, 3, 3, 2)
conv4 = fluid.layers.conv2d(conv3, 3, 3, 2)
conv5 = fluid.layers.conv2d(conv4, 3, 3, 2)
box, var = detection.prior_box(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
......@@ -112,39 +87,17 @@ class TestMultiBoxHead(unittest.TestCase):
data_shape = [3, 224, 224]
mbox_locs, mbox_confs = self.multi_box_output(data_shape)
for loc, conf in zip(mbox_locs, mbox_confs):
assert loc.shape[1:3] == conf.shape[1:3]
def multi_box_output(self, data_shape):
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(
input=images,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv2 = fluid.layers.conv2d(
input=conv1,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv3 = fluid.layers.conv2d(
input=conv2,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv4 = fluid.layers.conv2d(
input=conv3,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv5 = fluid.layers.conv2d(
input=conv4,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv1 = fluid.layers.conv2d(images, 3, 3, 2)
conv2 = fluid.layers.conv2d(conv1, 3, 3, 2)
conv3 = fluid.layers.conv2d(conv2, 3, 3, 2)
conv4 = fluid.layers.conv2d(conv3, 3, 3, 2)
conv5 = fluid.layers.conv2d(conv4, 3, 3, 2)
mbox_locs, mbox_confs = detection.multi_box_head(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册