From 54a8c04fab9310ef78f0b000ae411fd7ae706ee7 Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Tue, 27 Mar 2018 22:09:43 +0000 Subject: [PATCH] add inplace attr to bn --- python/paddle/fluid/layers/nn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 2db4e5d27d4..0332556f62c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1483,6 +1483,7 @@ def batch_norm(input, param_attr=None, bias_attr=None, data_layout='NCHW', + in_place=False, name=None, moving_mean_name=None, moving_variance_name=None): @@ -1538,7 +1539,7 @@ def batch_norm(input, saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) - batch_norm_out = helper.create_tmp_variable(dtype) + batch_norm_out = input if in_place else helper.create_tmp_variable(dtype) helper.append_op( type="batch_norm", -- GitLab