提交 54a8c04f 编写于 作者: Y Yang Yang

add inplace attr to bn

上级 25317bd3
...@@ -1483,6 +1483,7 @@ def batch_norm(input, ...@@ -1483,6 +1483,7 @@ def batch_norm(input,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
data_layout='NCHW', data_layout='NCHW',
in_place=False,
name=None, name=None,
moving_mean_name=None, moving_mean_name=None,
moving_variance_name=None): moving_variance_name=None):
...@@ -1538,7 +1539,7 @@ def batch_norm(input, ...@@ -1538,7 +1539,7 @@ def batch_norm(input,
saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
batch_norm_out = helper.create_tmp_variable(dtype) batch_norm_out = input if in_place else helper.create_tmp_variable(dtype)
helper.append_op( helper.append_op(
type="batch_norm", type="batch_norm",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册