Commit 3b0e43aa authored by chengduoZH

add config parse

Parent 0c273ff4
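This commit threads a depth dimension through the trainer config: LayerConfig gains a depth field, config_parser.py learns to read 3D image sizes and set them on layers, and the config helpers expose depth= on data_layer and img3D= on batch_norm_layer. A minimal sketch of the new user-facing surface, mirroring the test config added at the end of this diff (layer names and sizes here are illustrative only, not part of the patch):

from paddle.trainer_config_helpers import *

settings(batch_size=1000, learning_rate=1e-4)

# 3D input: size must equal channels * depth * height * width
# (here 1 * 3 * 6 * 20 = 360).
volume = data_layer(name='volume', size=360, depth=3, height=6, width=20)

# img3D=True makes BatchNormLayer parse the input as a 3D image
# (parse_image3d) instead of a 2D one (parse_image).
bn3d = batch_norm_layer(volume, num_channels=1, img3D=True)
outputs(bn3d)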
@@ -515,6 +515,8 @@ message LayerConfig {
  // for HuberRegressionLoss
  optional double delta = 57 [ default = 1.0 ];
+  // for 3D data
+  optional double depth = 58 [ default = 1 ];
}

message EvaluatorConfig {
...
@@ -1172,6 +1172,20 @@ def get_img_size(input_layer_name, channels):
    return img_size, img_size_y

+def get_img3d_size(input_layer_name, channels):
+    input = g_layer_map[input_layer_name]
+    img_pixels = input.size / channels
+    img_size = input.width
+    img_size_y = input.height
+    img_size_z = input.depth
+    config_assert(
+        img_size * img_size_y * img_size_z == img_pixels,
+        "Input layer %s: Incorrect input image size %d * %d * %d for input image pixels %d"
+        % (input_layer_name, img_size, img_size_y, img_size_z, img_pixels))
+    return img_size, img_size_y, img_size_z
+
def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
    parse_image(bilinear, input_layer_name, bilinear_conf.image_conf)
    bilinear_conf.out_size_x = bilinear.out_size_x
@@ -1224,6 +1238,12 @@ def parse_image(image, input_layer_name, image_conf):
        get_img_size(input_layer_name, image_conf.channels)

+def parse_image3d(image, input_layer_name, image_conf):
+    image_conf.channels = image.channels
+    image_conf.img_size, image_conf.img_size_y, image_conf.img_size_z = \
+        get_img3d_size(input_layer_name, image_conf.channels)
+
def parse_norm(norm, input_layer_name, norm_conf):
    norm_conf.norm_type = norm.norm_type
    config_assert(
@@ -1585,6 +1605,9 @@ class LayerBase(object):
        self.config.height = height
        self.config.width = width

+    def set_layer_depth(self, depth):
+        self.config.depth = depth
+
    def set_cnn_layer(self,
                      input_layer_name,
                      height,
@@ -1788,11 +1811,19 @@ class DetectionOutputLayer(LayerBase):
@config_layer('data')
class DataLayer(LayerBase):
-    def __init__(self, name, size, height=None, width=None, device=None):
+    def __init__(self,
+                 name,
+                 size,
+                 depth=None,
+                 height=None,
+                 width=None,
+                 device=None):
        super(DataLayer, self).__init__(
            name, 'data', size, inputs=[], device=device)
        if height and width:
            self.set_layer_height_width(height, width)
+        if depth:
+            self.set_layer_depth(depth)

'''
@@ -2077,6 +2108,7 @@ class BatchNormLayer(LayerBase):
                 name,
                 inputs,
                 bias=True,
+                 img3D=False,
                 use_global_stats=True,
                 moving_average_fraction=0.9,
                 batch_norm_type=None,
@@ -2121,15 +2153,33 @@ class BatchNormLayer(LayerBase):
        input_layer = self.get_input_layer(0)
        image_conf = self.config.inputs[0].image_conf
-        parse_image(self.inputs[0].image, input_layer.name, image_conf)
-
-        # Only pass the width and height of input to batch_norm layer
-        # when either of it is non-zero.
-        if input_layer.width != 0 or input_layer.height != 0:
-            self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
-                               image_conf.channels, False)
+        if img3D:
+            parse_image3d(self.inputs[0].image, input_layer.name, image_conf)
+            # Only pass the width and height of input to batch_norm layer
+            # when either of it is non-zero.
+            if input_layer.width != 0 or input_layer.height != 0:
+                self.set_cnn_layer(
+                    input_layer_name=name,
+                    depth=image_conf.img_size_z,
+                    height=image_conf.img_size_y,
+                    width=image_conf.img_size,
+                    channels=image_conf.channels,
+                    is_print=True)
+            else:
+                self.set_layer_size(input_layer.size)
        else:
-            self.set_layer_size(input_layer.size)
+            parse_image(self.inputs[0].image, input_layer.name, image_conf)
+            # Only pass the width and height of input to batch_norm layer
+            # when either of it is non-zero.
+            if input_layer.width != 0 or input_layer.height != 0:
+                self.set_cnn_layer(
+                    input_layer_name=name,
+                    height=image_conf.img_size_y,
+                    width=image_conf.img_size,
+                    channels=image_conf.channels,
+                    is_print=True)
+            else:
+                self.set_layer_size(input_layer.size)

        psize = self.calc_parameter_size(image_conf)
        dims = [1, psize]
@@ -2139,6 +2189,28 @@ class BatchNormLayer(LayerBase):
        self.create_bias_parameter(bias, psize)

+    def set_cnn_layer(self,
+                      input_layer_name,
+                      depth=None,
+                      height=None,
+                      width=None,
+                      channels=None,
+                      is_print=True):
+        depthIsNone = False
+        if depth is None:
+            depth = 1
+            depthIsNone = True
+        size = depth * height * width * channels
+        self.set_layer_size(size)
+        self.set_layer_height_width(height, width)
+        self.set_layer_depth(depth)
+        if is_print and depthIsNone:
+            print("output for %s: c = %d, h = %d, w = %d, size = %d" %
+                  (input_layer_name, channels, height, width, size))
+        elif is_print:
+            print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" %
+                  (input_layer_name, channels, depth, height, width, size))
+
    def calc_parameter_size(self, image_conf):
        return image_conf.channels
...
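The parsing changes above all enforce one invariant: a 3D layer's size must equal channels * depth * height * width. A small standalone illustration of the check that get_img3d_size and set_cnn_layer perform (this helper is not part of the patch, just a restatement of the arithmetic):

# Hypothetical helper restating the 3D size check used by get_img3d_size.
def check_img3d_size(size, channels, depth, height, width):
    img_pixels = size / channels            # pixels per channel
    assert width * height * depth == img_pixels, \
        "size=%d does not match c=%d d=%d h=%d w=%d" % (
            size, channels, depth, height, width)

# Numbers from the test config added below: size = 120 * 3 = 360,
# channels = 1, depth = 3, height = 6, width = 20 -> 20 * 6 * 3 == 360.
check_img3d_size(360, 1, 3, 6, 20)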
@@ -166,6 +166,7 @@ class LayerType(object):
    EXCONVTRANS_LAYER = 'exconvt'
    CUDNNCONV_LAYER = 'cudnn_conv'
    POOL_LAYER = 'pool'
+    POOL3D_LAYER = 'pool3d'
    BATCH_NORM_LAYER = 'batch_norm'
    NORM_LAYER = 'norm'
    SUM_TO_ONE_NORM_LAYER = 'sum_to_one_norm'
@@ -894,7 +895,8 @@ def mixed_layer(size=0,
@layer_support()
-def data_layer(name, size, height=None, width=None, layer_attr=None):
+def data_layer(name, size, depth=None, height=None, width=None,
+               layer_attr=None):
    """
    Define DataLayer For NeuralNetwork.
@@ -921,15 +923,18 @@ def data_layer(name, size, height=None, width=None, layer_attr=None):
        type=LayerType.DATA,
        name=name,
        size=size,
+        depth=depth,
        height=height,
        width=width,
        **ExtraLayerAttribute.to_kwargs(layer_attr))

+    if depth is None:
+        depth = 1
    num_filters = None
    if height is not None and width is not None:
-        num_filters = size / (width * height)
-        assert num_filters * width * height == size, \
-            "size=%s width=%s height=%s" % (size, width, height)
+        num_filters = size / (width * height * depth)
+        assert num_filters * width * height * depth == size, \
+            "size=%s width=%s height=%s depth=%s" % (size, width, height, depth)

    return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters)
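With depth threaded through data_layer, num_filters is now derived as size / (width * height * depth), and depth defaults to 1 so plain 2D configs behave exactly as before. As a worked example with the 3D data layer from the test config below: 360 / (20 * 6 * 3) = 1, and the assertion 1 * 20 * 6 * 3 == 360 holds.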
@@ -2799,6 +2804,7 @@ def img_cmrnorm_layer(input,
def batch_norm_layer(input,
                     act=None,
                     name=None,
+                     img3D=False,
                     num_channels=None,
                     bias_attr=None,
                     param_attr=None,
@@ -2885,6 +2891,7 @@ def batch_norm_layer(input,
        (batch_norm_type == "cudnn_batch_norm")
    l = Layer(
        name=name,
+        img3D=img3D,
        inputs=Input(
            input.name, image=Image(channels=num_channels), **param_attr.attr),
        active_type=act.name,
...
from paddle.trainer_config_helpers import *

settings(batch_size=1000, learning_rate=1e-4)

# 2D case: height/width only, depth defaults to 1 (1 * 6 * 30 = 180).
data = data_layer(name='data', size=180, width=30, height=6)
batchNorm = batch_norm_layer(data, num_channels=1)
outputs(batchNorm)

# 3D case: size = 1 channel * depth 3 * height 6 * width 20 = 360.
data3D = data_layer(name='data3D22', size=120 * 3, width=20, height=6, depth=3)
print(data3D)
batchNorm3D = batch_norm_layer(data3D, num_channels=1, img3D=True)
outputs(batchNorm3D)
@@ -16,4 +16,4 @@ from paddle.trainer.config_parser import parse_config_and_serialize
if __name__ == '__main__':
    parse_config_and_serialize(
-        'trainer_config_helpers/tests/layers_test_config.py', '')
+        'trainer_config_helpers/tests/configs/test_BatchNorm3D.py', '')