提交 aea05b6e 编写于 作者: C chengduoZH

add param mean_var_names

上级 26638e9c
...@@ -2366,6 +2366,7 @@ class BatchNormLayer(LayerBase): ...@@ -2366,6 +2366,7 @@ class BatchNormLayer(LayerBase):
inputs, inputs,
bias=True, bias=True,
use_global_stats=True, use_global_stats=True,
mean_var_names=None,
moving_average_fraction=0.9, moving_average_fraction=0.9,
batch_norm_type=None, batch_norm_type=None,
**xargs): **xargs):
...@@ -2421,11 +2422,11 @@ class BatchNormLayer(LayerBase): ...@@ -2421,11 +2422,11 @@ class BatchNormLayer(LayerBase):
psize = self.calc_parameter_size(image_conf) psize = self.calc_parameter_size(image_conf)
dims = [1, psize] dims = [1, psize]
if mean_var_names is not None:
assert len(mean_var_names) == 2
self.inputs[1].parameter_name = mean_var_names[0]
self.inputs[2].parameter_name = mean_var_names[1]
self.inputs[1].parameter_name = self.inputs[0].parameter_name.split('.')[0] + '.' + \
self.inputs[1].parameter_name.split('.')[1]
self.inputs[2].parameter_name = self.inputs[0].parameter_name.split('.')[0] + '.' + \
self.inputs[2].parameter_name.split('.')[1]
self.create_input_parameter(0, psize) self.create_input_parameter(0, psize)
self.create_input_parameter(1, psize, dims) self.create_input_parameter(1, psize, dims)
self.create_input_parameter(2, psize, dims) self.create_input_parameter(2, psize, dims)
......
...@@ -2957,6 +2957,7 @@ def batch_norm_layer(input, ...@@ -2957,6 +2957,7 @@ def batch_norm_layer(input,
bias_attr=None, bias_attr=None,
param_attr=None, param_attr=None,
layer_attr=None, layer_attr=None,
mean_var_names=None,
batch_norm_type=None, batch_norm_type=None,
moving_average_fraction=0.9, moving_average_fraction=0.9,
use_global_stats=None): use_global_stats=None):
...@@ -3014,6 +3015,8 @@ def batch_norm_layer(input, ...@@ -3014,6 +3015,8 @@ def batch_norm_layer(input,
:type param_attr: ParameterAttribute :type param_attr: ParameterAttribute
:param layer_attr: Extra Layer Attribute. :param layer_attr: Extra Layer Attribute.
:type layer_attr: ExtraLayerAttribute :type layer_attr: ExtraLayerAttribute
:param mean_var_names: [mean name, variance name]
:type mean_var_names: string list
:param use_global_stats: whether use moving mean/variance statistics :param use_global_stats: whether use moving mean/variance statistics
during testing peroid. If None or True, during testing peroid. If None or True,
it will use moving mean/variance statistics during it will use moving mean/variance statistics during
...@@ -3044,6 +3047,7 @@ def batch_norm_layer(input, ...@@ -3044,6 +3047,7 @@ def batch_norm_layer(input,
active_type=act.name, active_type=act.name,
type=LayerType.BATCH_NORM_LAYER, type=LayerType.BATCH_NORM_LAYER,
batch_norm_type=batch_norm_type, batch_norm_type=batch_norm_type,
mean_var_names=mean_var_names,
bias=ParamAttr.to_bias(bias_attr), bias=ParamAttr.to_bias(bias_attr),
moving_average_fraction=moving_average_fraction, moving_average_fraction=moving_average_fraction,
use_global_stats=use_global_stats, use_global_stats=use_global_stats,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册