提交 b1c0bad9 编写于 作者: C chengduoZH

Add config parser for pooling3D

上级 860bf192
......@@ -2255,9 +2255,7 @@ void CpuMatrix::maxPool3DBackward(Matrix& outGrad,
real* tgtGrad = getData();
real* otGrad = outGrad.getData();
real* maxPoolIdxData = maxPoolIdx.getData();
size_t outStride = outGrad.getStride();
;
for (size_t n = 0; n < num; ++n) {
if (!outGrad.isContiguous()) {
......
......@@ -495,6 +495,7 @@ message LayerConfig {
// to indicate rectangle image data
optional uint64 height = 50;
optional uint64 width = 51;
optional uint64 depth = 57 [ default = 1 ];
// blank label used in ctc loss
optional uint32 blank = 52 [ default = 0 ];
......
......@@ -903,6 +903,31 @@ class Pool(Cfg):
self.add_keys(locals())
@config_class
class Pool3d(Cfg):
def __init__(
self,
pool_type,
channels,
size_x,
size_y=None,
size_z=None,
start=None,
stride=None, # 1 by defalut in protobuf
stride_y=None,
stride_z=None,
padding=None, # 0 by defalut in protobuf
padding_y=None,
padding_z=None):
self.add_keys(locals())
self.filter_size_y = size_y if size_y else size_x
self.filter_size_z = size_z if size_z else size_x
self.padding_y = padding_y if padding_y else padding
self.padding_z = padding_z if padding_z else padding
self.stride_y = stride_y if stride_y else stride
self.stride_z = stride_z if stride_z else stride
@config_class
class SpatialPyramidPool(Cfg):
def __init__(self, pool_type, pyramid_height, channels):
......@@ -1167,6 +1192,20 @@ def get_img_size(input_layer_name, channels):
return img_size, img_size_y
def get_img3d_size(input_layer_name, channels):
input = g_layer_map[input_layer_name]
img_pixels = input.size / channels
img_size = input.width
img_size_y = input.height
img_size_z = input.depth
config_assert(
img_size * img_size_y * img_size_z == img_pixels,
"Input layer %s: Incorrect input image size %d * %d * %d for input image pixels %d"
% (input_layer_name, img_size, img_size_y, img_size_z, img_pixels))
return img_size, img_size_y, img_size_z
def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
parse_image(bilinear, input_layer_name, bilinear_conf.image_conf)
bilinear_conf.out_size_x = bilinear.out_size_x
......@@ -1204,6 +1243,45 @@ def parse_pool(pool, input_layer_name, pool_conf, ceil_mode):
pool_conf.stride_y, not ceil_mode)
def parse_pool3d(pool, input_layer_name, pool_conf, ceil_mode):
pool_conf.pool_type = pool.pool_type
config_assert(pool.pool_type in ['max-projection', 'avg-projection'],
"pool-type %s is not in "
"['max-projection', 'avg-projection']" % pool.pool_type)
pool_conf.channels = pool.channels
pool_conf.size_x = pool.size_x
pool_conf.stride = pool.stride
pool_conf.padding = pool.padding
pool_conf.size_y = default(pool.size_y, pool_conf.size_x)
pool_conf.size_z = default(pool.size_z, pool_conf.size_x)
pool_conf.stride_y = default(pool.stride_y, pool_conf.stride)
pool_conf.stride_z = default(pool.stride_z, pool_conf.stride)
pool_conf.padding_y = default(pool.padding_y, pool_conf.padding)
pool_conf.padding_z = default(pool.padding_z, pool_conf.padding)
pool_conf.img_size, pool_conf.img_size_y, pool_conf.img_size_z = \
get_img3d_size(input_layer_name, pool.channels)
config_assert(not pool.start, "start is deprecated in pooling.")
if pool.padding is not None:
pool_conf.padding = pool.padding
pool_conf.padding_y = default(pool.padding_y, pool_conf.padding)
pool_conf.padding_z = default(pool.padding_z, pool_conf.padding)
pool_conf.output_x = cnn_output_size(pool_conf.img_size, pool_conf.size_x,
pool_conf.padding, pool_conf.stride,
not ceil_mode)
pool_conf.output_y = cnn_output_size(pool_conf.img_size_y, pool_conf.size_y,
pool_conf.padding_y,
pool_conf.stride_y, not ceil_mode)
pool_conf.output_z = cnn_output_size(pool_conf.img_size_z, pool_conf.size_z,
pool_conf.padding_z,
pool_conf.stride_z, not ceil_mode)
def parse_spp(spp, input_layer_name, spp_conf):
parse_image(spp, input_layer_name, spp_conf.image_conf)
spp_conf.pool_type = spp.pool_type
......@@ -1580,6 +1658,9 @@ class LayerBase(object):
self.config.height = height
self.config.width = width
def set_layer_depth(self, depth):
self.config.depth = depth
def set_cnn_layer(self,
input_layer_name,
height,
......@@ -1763,11 +1844,19 @@ class DetectionOutputLayer(LayerBase):
@config_layer('data')
class DataLayer(LayerBase):
def __init__(self, name, size, height=None, width=None, device=None):
def __init__(self,
name,
size,
depth=None,
height=None,
width=None,
device=None):
super(DataLayer, self).__init__(
name, 'data', size, inputs=[], device=device)
if height and width:
self.set_layer_height_width(height, width)
if depth:
self.set_layer_depth(depth)
'''
......@@ -1995,6 +2084,35 @@ class PoolLayer(LayerBase):
pool_conf.channels)
@config_layer('pool3d')
class Pool3DLayer(LayerBase):
def __init__(self, name, inputs, ceil_mode=True, **xargs):
super(Pool3DLayer, self).__init__(
name, 'pool3d', 0, inputs=inputs, **xargs)
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
pool_conf = self.config.inputs[input_index].pool_conf
parse_pool3d(self.inputs[input_index].pool, input_layer.name,
pool_conf, ceil_mode)
self.set_cnn_layer(name, pool_conf.output_z, pool_conf.output_y,
pool_conf.output_x, pool_conf.channels)
def set_cnn_layer(self,
input_layer_name,
depth,
height,
width,
channels,
is_print=True):
size = depth * height * width * channels
self.set_layer_size(size)
self.set_layer_height_width(height, width)
self.set_layer_depth(depth)
if is_print:
print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" %
(input_layer_name, channels, depth, height, width, size))
@config_layer('spp')
class SpatialPyramidPoolLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
......
......@@ -133,6 +133,7 @@ __all__ = [
'clip_layer',
'slice_projection',
'kmax_sequence_score_layer',
'img_pool3d_layer',
]
......@@ -161,6 +162,7 @@ class LayerType(object):
EXCONVTRANS_LAYER = 'exconvt'
CUDNNCONV_LAYER = 'cudnn_conv'
POOL_LAYER = 'pool'
POOL3D_LAYER = 'pool3d'
BATCH_NORM_LAYER = 'batch_norm'
NORM_LAYER = 'norm'
SUM_TO_ONE_NORM_LAYER = 'sum_to_one_norm'
......@@ -878,7 +880,8 @@ def mixed_layer(size=0,
@layer_support()
def data_layer(name, size, height=None, width=None, layer_attr=None):
def data_layer(name, size, depth=None, height=None, width=None,
layer_attr=None):
"""
Define DataLayer For NeuralNetwork.
......@@ -905,6 +908,7 @@ def data_layer(name, size, height=None, width=None, layer_attr=None):
type=LayerType.DATA,
name=name,
size=size,
depth=depth,
height=height,
width=width,
**ExtraLayerAttribute.to_kwargs(layer_attr))
......@@ -2610,6 +2614,146 @@ def img_pool_layer(input,
size=l.config.size)
@wrap_name_default("pool3d")
@layer_support()
def img_pool3d_layer(input,
pool_size,
name=None,
num_channels=None,
pool_type=None,
stride=1,
padding=0,
layer_attr=None,
pool_size_y=None,
stride_y=None,
padding_y=None,
pool_size_z=None,
stride_z=None,
padding_z=None,
ceil_mode=True):
"""
Image pooling Layer.
The details of pooling layer, please refer ufldl's pooling_ .
.. _pooling: http://ufldl.stanford.edu/tutorial/supervised/Pooling/
- ceil_mode=True:
.. math::
w = 1 + int(ceil(input\_width + 2 * padding - pool\_size) / float(stride))
h = 1 + int(ceil(input\_height + 2 * padding\_y - pool\_size\_y) / float(stride\_y))
d = 1 + int(ceil(input\_depth + 2 * padding\_z - pool\_size\_z) / float(stride\_z))
- ceil_mode=False:
.. math::
w = 1 + int(floor(input\_width + 2 * padding - pool\_size) / float(stride))
h = 1 + int(floor(input\_height + 2 * padding\_y - pool\_size\_y) / float(stride\_y))
d = 1 + int(floor(input\_depth + 2 * padding\_z - pool\_size\_z) / float(stride\_z))
The example usage is:
.. code-block:: python
maxpool = img_pool3d_layer(input=conv,
pool_size=3,
num_channels=8,
stride=1,
padding=1,
pool_type=MaxPooling())
:param padding: pooling padding width.
:type padding: int|tuple|list
:param name: name of pooling layer
:type name: basestring.
:param input: layer's input
:type input: LayerOutput
:param pool_size: pooling window width
:type pool_size: int|tuple|list
:param num_channels: number of input channel.
:type num_channels: int
:param pool_type: pooling type. MaxPooling or AvgPooling. Default is
MaxPooling.
:type pool_type: BasePoolingType
:param stride: stride width of pooling.
:type stride: int|tuple|list
:param layer_attr: Extra Layer attribute.
:type layer_attr: ExtraLayerAttribute
:param ceil_mode: Wether to use ceil mode to calculate output height and with.
Defalut is True. If set false, Otherwise use floor.
:type ceil_mode: bool
:return: LayerOutput object.
:rtype: LayerOutput
"""
if num_channels is None:
assert input.num_filters is not None
num_channels = input.num_filters
if pool_type is None:
pool_type = MaxPooling()
elif isinstance(pool_type, AvgPooling):
pool_type.name = 'avg'
type_name = pool_type.name + '-projection' \
if (
isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)) \
else pool_type.name
if isinstance(pool_size, collections.Sequence):
assert len(pool_size) == 3
pool_size, pool_size_y, pool_size_z = pool_size
else:
pool_size_y = pool_size
pool_size_z = pool_size
if isinstance(stride, collections.Sequence):
assert len(stride) == 3
stride, stride_y, stride_z = stride
else:
stride_y = stride
stride_z = stride
if isinstance(padding, collections.Sequence):
assert len(padding) == 3
padding, padding_y, padding_y = padding
else:
padding_y = padding
padding_z = padding
l = Layer(
name=name,
type=LayerType.POOL3D_LAYER,
inputs=[
Input(
input.name,
pool=Pool3d(
pool_type=type_name,
channels=num_channels,
size_x=pool_size,
start=None,
stride=stride,
padding=padding,
size_y=pool_size_y,
stride_y=stride_y,
padding_y=padding_y,
size_z=pool_size_z,
stride_z=stride_z,
padding_z=padding_z))
],
ceil_mode=ceil_mode,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name,
LayerType.POOL_LAYER,
parents=[input],
num_filters=num_channels,
size=l.config.size)
@wrap_name_default("spp")
@layer_support()
def spp_layer(input,
......
from paddle.trainer_config_helpers import *
settings(batch_size=100, learning_rate=1e-5)
data_2d = data_layer(name='data_2d', size=6000, height=20, width=10)
pool_2d = img_pool_layer(
name="pool___2d",
input=data_2d,
num_channels=30,
pool_size=5,
stride=3,
padding=1,
pool_type=AvgPooling())
outputs(pool_2d)
data_3d = data_layer(
name='data_3d_1', size=60000, depth=10, height=20, width=10)
pool_3d_1 = img_pool3d_layer(
name="pool_3d_1",
input=data_3d,
num_channels=30,
pool_size=5,
stride=3,
padding=1,
pool_type=AvgPooling())
outputs(pool_3d_1)
pool_3d_2 = img_pool3d_layer(
name="pool_3d_2",
input=data_3d,
num_channels=30,
pool_size=[5, 5, 5],
stride=[3, 3, 3],
padding=[1, 1, 1],
pool_type=MaxPooling())
outputs(pool_3d_2)
......@@ -16,4 +16,4 @@ from paddle.trainer.config_parser import parse_config_and_serialize
if __name__ == '__main__':
parse_config_and_serialize(
'trainer_config_helpers/tests/layers_test_config.py', '')
'trainer_config_helpers/tests/configs/test_pooling3D_layer.py', '')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册