未验证 提交 bc90916e 编写于 作者: Y Yiqun Liu 提交者: GitHub

Do not define and save reserve_space for inference. (#32375)

上级 4be3b057
......@@ -223,24 +223,27 @@ def batch_norm(x,
helper = LayerHelper('batch_norm', **locals())
dtype = x.dtype if x.dtype is not 'float16' else 'float32'
param_dtype = x.dtype if x.dtype is not 'float16' else 'float32'
saved_mean = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
dtype=param_dtype, stop_gradient=True)
saved_variance = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
batch_norm_out = helper.create_variable_for_type_inference(dtype)
reserve_space = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
dtype=param_dtype, stop_gradient=True)
batch_norm_out = helper.create_variable_for_type_inference(x.dtype)
outputs = {
"Y": [batch_norm_out],
"MeanOut": [running_mean],
"VarianceOut": [running_var],
"SavedMean": [saved_mean],
"SavedVariance": [saved_variance],
"ReserveSpace": [reserve_space]
"SavedVariance": [saved_variance]
}
if training or trainable_statistics:
# reserve_space is only used for training.
reserve_space = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
outputs["ReserveSpace"] = [reserve_space]
helper.append_op(
type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册