From 454a1a291279818d315a88ef7a91e019a99c73bf Mon Sep 17 00:00:00 2001 From: wangyang59 Date: Wed, 14 Dec 2016 17:51:47 -0800 Subject: [PATCH] fixed a bug for demo/gan caused by batchNormLayer --- demo/gan/gan_conf_image.py | 4 ++-- python/paddle/trainer/config_parser.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/demo/gan/gan_conf_image.py b/demo/gan/gan_conf_image.py index f89a4e706c..c469227994 100644 --- a/demo/gan/gan_conf_image.py +++ b/demo/gan/gan_conf_image.py @@ -87,9 +87,9 @@ def conv_bn(input, print(imgSize, output_x, stride, filter_size, padding) if trans: - nameApx = "_conv" - else: nameApx = "_convt" + else: + nameApx = "_conv" if bn: conv = img_conv_layer( diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 5b7f4d85e2..ea3e4308fe 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1871,8 +1871,14 @@ class BatchNormLayer(LayerBase): input_layer = self.get_input_layer(0) image_conf = self.config.inputs[0].image_conf parse_image(self.inputs[0].image, input_layer.name, image_conf) - self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size, - image_conf.channels, False) + + # Only pass the width and height of input to batch_norm layer + # when either of it is non-zero. + if input_layer.width != 0 or input_layer.height != 0: + self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size, + image_conf.channels, True) + else: + self.set_layer_size(input_layer.size) psize = self.calc_parameter_size(image_conf) dims = [1, psize] -- GitLab