Commit a225bfa3 authored by qingqing01, committed by GitHub

Merge pull request #3982 from PaddlePaddle/fix_batch_norm_parameter_share

Fix batch_norm parameter sharing
...@@ -2368,6 +2368,7 @@ class BatchNormLayer(LayerBase): ...@@ -2368,6 +2368,7 @@ class BatchNormLayer(LayerBase):
use_global_stats=True, use_global_stats=True,
moving_average_fraction=0.9, moving_average_fraction=0.9,
batch_norm_type=None, batch_norm_type=None,
mean_var_names=None,
**xargs): **xargs):
if inputs is None: if inputs is None:
inputs = [] inputs = []
...@@ -2421,6 +2422,11 @@ class BatchNormLayer(LayerBase): ...@@ -2421,6 +2422,11 @@ class BatchNormLayer(LayerBase):
psize = self.calc_parameter_size(image_conf) psize = self.calc_parameter_size(image_conf)
dims = [1, psize] dims = [1, psize]
if mean_var_names is not None:
assert len(mean_var_names) == 2
self.inputs[1].parameter_name = mean_var_names[0]
self.inputs[2].parameter_name = mean_var_names[1]
self.create_input_parameter(0, psize) self.create_input_parameter(0, psize)
self.create_input_parameter(1, psize, dims) self.create_input_parameter(1, psize, dims)
self.create_input_parameter(2, psize, dims) self.create_input_parameter(2, psize, dims)
......
...@@ -2959,7 +2959,8 @@ def batch_norm_layer(input, ...@@ -2959,7 +2959,8 @@ def batch_norm_layer(input,
layer_attr=None, layer_attr=None,
batch_norm_type=None, batch_norm_type=None,
moving_average_fraction=0.9, moving_average_fraction=0.9,
use_global_stats=None): use_global_stats=None,
mean_var_names=None):
""" """
Batch Normalization Layer. The notation of this layer as follow. Batch Normalization Layer. The notation of this layer as follow.
...@@ -3026,6 +3027,8 @@ def batch_norm_layer(input, ...@@ -3026,6 +3027,8 @@ def batch_norm_layer(input,
:math:`runningMean = newMean*(1-factor) :math:`runningMean = newMean*(1-factor)
+ runningMean*factor` + runningMean*factor`
:type moving_average_fraction: float. :type moving_average_fraction: float.
:param mean_var_names: [mean name, variance name]
:type mean_var_names: string list
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
""" """
...@@ -3047,6 +3050,7 @@ def batch_norm_layer(input, ...@@ -3047,6 +3050,7 @@ def batch_norm_layer(input,
bias=ParamAttr.to_bias(bias_attr), bias=ParamAttr.to_bias(bias_attr),
moving_average_fraction=moving_average_fraction, moving_average_fraction=moving_average_fraction,
use_global_stats=use_global_stats, use_global_stats=use_global_stats,
mean_var_names=mean_var_names,
**ExtraLayerAttribute.to_kwargs(layer_attr)) **ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput( return LayerOutput(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to post a comment