diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py
index 54824233f70762a932c81d4c07814e004affd3cb..e6971b3781c3ba386edef2fcfe05385d5fc0ae47 100644
--- a/python/paddle/nn/functional/norm.py
+++ b/python/paddle/nn/functional/norm.py
@@ -223,24 +223,27 @@ def batch_norm(x,
     helper = LayerHelper('batch_norm', **locals())
 
-    dtype = x.dtype if x.dtype is not 'float16' else 'float32'
+    param_dtype = x.dtype if x.dtype is not 'float16' else 'float32'
     saved_mean = helper.create_variable_for_type_inference(
-        dtype=dtype, stop_gradient=True)
+        dtype=param_dtype, stop_gradient=True)
     saved_variance = helper.create_variable_for_type_inference(
-        dtype=dtype, stop_gradient=True)
-    batch_norm_out = helper.create_variable_for_type_inference(dtype)
-    reserve_space = helper.create_variable_for_type_inference(
-        dtype=x.dtype, stop_gradient=True)
+        dtype=param_dtype, stop_gradient=True)
+    batch_norm_out = helper.create_variable_for_type_inference(x.dtype)
 
     outputs = {
         "Y": [batch_norm_out],
         "MeanOut": [running_mean],
         "VarianceOut": [running_var],
         "SavedMean": [saved_mean],
-        "SavedVariance": [saved_variance],
-        "ReserveSpace": [reserve_space]
+        "SavedVariance": [saved_variance]
     }
 
+    if training or trainable_statistics:
+        # reserve_space is only used for training.
+        reserve_space = helper.create_variable_for_type_inference(
+            dtype=x.dtype, stop_gradient=True)
+        outputs["ReserveSpace"] = [reserve_space]
+
     helper.append_op(
         type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
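As a rough illustration (not part of the patch), the sketch below builds a small static-graph program and calls `paddle.nn.functional.batch_norm` with `training=False`, which is the case the patched `LayerHelper` branch now handles without a `ReserveSpace` output. The shapes, parameter names, and the final inspection of the op's `output_names` are made up for this example; the expectation in the comment follows from the `if training or trainable_statistics:` guard added above.

```python
import paddle
import paddle.nn.functional as F

# Minimal sketch (assumed Paddle 2.x static-graph API): exercise the
# LayerHelper branch patched above with training=False, where ReserveSpace
# is expected to be omitted from the appended batch_norm op's outputs.
paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Hypothetical shapes: NCHW input with 8 channels.
    x = paddle.static.data(name='x', shape=[4, 8, 16, 16], dtype='float32')
    running_mean = paddle.static.create_parameter(shape=[8], dtype='float32')
    running_var = paddle.static.create_parameter(shape=[8], dtype='float32')
    weight = paddle.static.create_parameter(shape=[8], dtype='float32')
    bias = paddle.static.create_parameter(shape=[8], dtype='float32')

    y = F.batch_norm(x, running_mean, running_var, weight, bias,
                     training=False)

    bn_ops = [op for op in main_prog.global_block().ops
              if op.type == 'batch_norm']
    # After this change, an eval-mode batch_norm op should no longer carry
    # a populated ReserveSpace output.
    print(bn_ops[-1].output_names)
```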