未验证 提交 50a6e7c5 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #8534 from JiayiFeng/fix_bn_status

Change BN's 'mean' and 'variance' from persistable variable to Parameter
...@@ -1519,21 +1519,21 @@ def batch_norm(input, ...@@ -1519,21 +1519,21 @@ def batch_norm(input,
bias = helper.create_parameter( bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True) attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
mean = helper.create_global_variable( mean = helper.create_parameter(
name=moving_mean_name, attr=ParamAttr(
dtype=input.dtype, name=moving_mean_name, initializer=Constant(0.0), trainable=False),
shape=param_shape, shape=param_shape,
persistable=True, dtype=input.dtype)
stop_gradient=True) mean.stop_gradient = True
helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
variance = helper.create_global_variable( variance = helper.create_parameter(
name=moving_variance_name, attr=ParamAttr(
dtype=input.dtype, name=moving_variance_name,
initializer=Constant(1.0),
trainable=False),
shape=param_shape, shape=param_shape,
persistable=True, dtype=input.dtype)
stop_gradient=True) variance.stop_gradient = True
helper.set_variable_initializer(var=variance, initializer=Constant(1.0))
# create output # create output
# mean and mean_out share the same memory # mean and mean_out share the same memory
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册