diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 11dc84ae20679bb73735f9119739fca5ea7fa673..2a6b6d5e2bf86e67ddad0499743bcc06260924f2 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -2368,6 +2368,7 @@ class BatchNormLayer(LayerBase):
                  use_global_stats=True,
                  moving_average_fraction=0.9,
                  batch_norm_type=None,
+                 mean_var_names=None,
                  **xargs):
         if inputs is None:
             inputs = []
@@ -2421,6 +2422,11 @@ class BatchNormLayer(LayerBase):
         psize = self.calc_parameter_size(image_conf)
         dims = [1, psize]
 
+        if mean_var_names is not None:
+            assert len(mean_var_names) == 2
+            self.inputs[1].parameter_name = mean_var_names[0]
+            self.inputs[2].parameter_name = mean_var_names[1]
+
         self.create_input_parameter(0, psize)
         self.create_input_parameter(1, psize, dims)
         self.create_input_parameter(2, psize, dims)
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index cba45bd3afa178ab4dd3a50f0947b144e7466e53..e1703c158a599fe14740cf89062e9792ce76dab2 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -2959,7 +2959,8 @@ def batch_norm_layer(input,
                      layer_attr=None,
                      batch_norm_type=None,
                      moving_average_fraction=0.9,
-                     use_global_stats=None):
+                     use_global_stats=None,
+                     mean_var_names=None):
     """
     Batch Normalization Layer. The notation of this layer as follow.
 
@@ -3026,6 +3027,8 @@ def batch_norm_layer(input,
                                     :math:`runningMean = newMean*(1-factor)
                                     + runningMean*factor`
     :type moving_average_fraction: float.
+    :param mean_var_names: [mean name, variance name]
+    :type mean_var_names: string list
     :return: LayerOutput object.
     :rtype: LayerOutput
     """
@@ -3047,6 +3050,7 @@ def batch_norm_layer(input,
         bias=ParamAttr.to_bias(bias_attr),
         moving_average_fraction=moving_average_fraction,
         use_global_stats=use_global_stats,
+        mean_var_names=mean_var_names,
         **ExtraLayerAttribute.to_kwargs(layer_attr))
 
     return LayerOutput(
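
Usage sketch (not part of the patch): the v1-style trainer config below is a minimal, hypothetical example of passing the new mean_var_names argument through batch_norm_layer; the config-parser change above maps the two entries onto the parameter names of the layer's mean (input 1) and variance (input 2) inputs. The layer sizes and the parameter names 'bn_mean' / 'bn_var' are illustrative assumptions, not values mandated by this change.

    from paddle.trainer_config_helpers import *

    img = data_layer(name='image', size=784)
    hidden = fc_layer(input=img, size=128, act=ReluActivation())

    # Pin the running mean/variance parameters of this batch-norm layer to
    # explicit names, so that another layer or another config loading the same
    # parameter files can share the moving statistics by using identical names.
    # The names below are hypothetical examples.
    bn = batch_norm_layer(
        input=hidden,
        act=ReluActivation(),
        mean_var_names=['bn_mean', 'bn_var'])

    outputs(fc_layer(input=bn, size=10, act=SoftmaxActivation()))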