diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 2db4e5d27d40e902e2da17b3748c64a1eb3d052b..0332556f62c46b187bd79841e4969d9da08b57a5 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1483,6 +1483,7 @@ def batch_norm(input, param_attr=None, bias_attr=None, data_layout='NCHW', + in_place=False, name=None, moving_mean_name=None, moving_variance_name=None): @@ -1538,7 +1539,7 @@ def batch_norm(input, saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) - batch_norm_out = helper.create_tmp_variable(dtype) + batch_norm_out = input if in_place else helper.create_tmp_variable(dtype) helper.append_op( type="batch_norm",