提交 454a1a29 编写于 作者: W wangyang59

fixed a bug for demo/gan caused by batchNormLayer

上级 dd894c29
...@@ -87,9 +87,9 @@ def conv_bn(input, ...@@ -87,9 +87,9 @@ def conv_bn(input,
print(imgSize, output_x, stride, filter_size, padding) print(imgSize, output_x, stride, filter_size, padding)
if trans: if trans:
nameApx = "_conv"
else:
nameApx = "_convt" nameApx = "_convt"
else:
nameApx = "_conv"
if bn: if bn:
conv = img_conv_layer( conv = img_conv_layer(
......
...@@ -1871,8 +1871,14 @@ class BatchNormLayer(LayerBase): ...@@ -1871,8 +1871,14 @@ class BatchNormLayer(LayerBase):
input_layer = self.get_input_layer(0) input_layer = self.get_input_layer(0)
image_conf = self.config.inputs[0].image_conf image_conf = self.config.inputs[0].image_conf
parse_image(self.inputs[0].image, input_layer.name, image_conf) parse_image(self.inputs[0].image, input_layer.name, image_conf)
self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
image_conf.channels, False) # Only pass the width and height of input to batch_norm layer
# when either of it is non-zero.
if input_layer.width != 0 or input_layer.height != 0:
self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
image_conf.channels, True)
else:
self.set_layer_size(input_layer.size)
psize = self.calc_parameter_size(image_conf) psize = self.calc_parameter_size(image_conf)
dims = [1, psize] dims = [1, psize]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册