提交 bbf3b47b 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #966 from qingqing01/batch_norm

Fix bug in config_parse.py when batch_norm layer is used in RecurrentLayerGroup
...@@ -498,9 +498,16 @@ class Input(Cfg): ...@@ -498,9 +498,16 @@ class Input(Cfg):
is_static=None, is_static=None,
is_shared=None, is_shared=None,
update_hooks=None, update_hooks=None,
input_layer_argument=None, ): input_layer_argument=None,
make_layer_name_in_submodel=True, ):
"""
@param make_layer_name_in_submodel True by defalut, you might need to
set it carefully when adding Input in config_parser.py.
"""
self.add_keys(locals()) self.add_keys(locals())
self.input_layer_name = MakeLayerNameInSubmodel(input_layer_name) self.input_layer_name = MakeLayerNameInSubmodel(
input_layer_name
) if make_layer_name_in_submodel else input_layer_name
# Define a projection for iexed layer # Define a projection for iexed layer
...@@ -1848,7 +1855,8 @@ class BatchNormLayer(LayerBase): ...@@ -1848,7 +1855,8 @@ class BatchNormLayer(LayerBase):
initial_std=0.0, initial_std=0.0,
initial_mean=0.0, initial_mean=0.0,
is_static=True, is_static=True,
is_shared=is_shared, )) is_shared=is_shared,
make_layer_name_in_submodel=False, ))
parallel_nn = bool(int(g_command_config_args.get("parallel_nn", 0))) parallel_nn = bool(int(g_command_config_args.get("parallel_nn", 0)))
cudnn_version = int(g_command_config_args.get("cudnn_version", 0)) cudnn_version = int(g_command_config_args.get("cudnn_version", 0))
...@@ -1880,7 +1888,7 @@ class BatchNormLayer(LayerBase): ...@@ -1880,7 +1888,7 @@ class BatchNormLayer(LayerBase):
# when either of it is non-zero. # when either of it is non-zero.
if input_layer.width != 0 or input_layer.height != 0: if input_layer.width != 0 or input_layer.height != 0:
self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size, self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
image_conf.channels, True) image_conf.channels, False)
else: else:
self.set_layer_size(input_layer.size) self.set_layer_size(input_layer.size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册