未验证 提交 482d297b 编写于 作者: Q qingqing01 提交者: GitHub

Disable in_place in batch_norm API. (#12736) (#12875)

* Disable in_place in batch_norm API.
上级 f1bf1334
...@@ -135,7 +135,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -135,7 +135,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Variance", AddInput("Variance",
"The global variance (for training) " "The global variance (for training) "
"or estimated Variance (for testing)"); "or estimated Variance (for testing)");
AddOutput("Y", "result after normalization").Reuse("X"); AddOutput("Y", "result after normalization");
AddOutput("MeanOut", AddOutput("MeanOut",
"Share memory with Mean. " "Share memory with Mean. "
"Store the global mean when training") "Store the global mean when training")
......
...@@ -27,6 +27,7 @@ from . import utils ...@@ -27,6 +27,7 @@ from . import utils
import random import random
from .. import unique_name from .. import unique_name
from functools import reduce from functools import reduce
import warnings
__all__ = [ __all__ = [
'fc', 'fc',
...@@ -2045,7 +2046,7 @@ def batch_norm(input, ...@@ -2045,7 +2046,7 @@ def batch_norm(input,
param_attr(ParamAttr): The parameter attribute for Parameter `scale`. param_attr(ParamAttr): The parameter attribute for Parameter `scale`.
bias_attr(ParamAttr): The parameter attribute for Parameter `bias`. bias_attr(ParamAttr): The parameter attribute for Parameter `bias`.
data_layout(string, default NCHW): NCHW|NHWC data_layout(string, default NCHW): NCHW|NHWC
in_place(bool, Default False): Make the input and output of batch norm reuse memory. in_place(bool, Default False): This argument is deprecated since 0.15.0.
use_mkldnn(bool, Default false): ${use_mkldnn_comment} use_mkldnn(bool, Default false): ${use_mkldnn_comment}
name(string, Default None): A name for this layer(optional). If set None, the layer name(string, Default None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
...@@ -2067,6 +2068,10 @@ def batch_norm(input, ...@@ -2067,6 +2068,10 @@ def batch_norm(input,
helper = LayerHelper('batch_norm', **locals()) helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
if in_place:
raise warnings.warn("The argument in_place is deprecated since 0.15.0, "
"please do not set it True.")
input_shape = input.shape input_shape = input.shape
if data_layout == 'NCHW': if data_layout == 'NCHW':
channel_num = input_shape[1] channel_num = input_shape[1]
...@@ -2116,7 +2121,7 @@ def batch_norm(input, ...@@ -2116,7 +2121,7 @@ def batch_norm(input,
saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
batch_norm_out = input if in_place else helper.create_tmp_variable(dtype) batch_norm_out = helper.create_tmp_variable(dtype)
helper.append_op( helper.append_op(
type="batch_norm", type="batch_norm",
......
...@@ -229,7 +229,7 @@ def img_conv_group(input, ...@@ -229,7 +229,7 @@ def img_conv_group(input,
use_mkldnn=use_mkldnn) use_mkldnn=use_mkldnn)
if conv_with_batchnorm[i]: if conv_with_batchnorm[i]:
tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True) tmp = layers.batch_norm(input=tmp, act=conv_act)
drop_rate = conv_batchnorm_drop_rate[i] drop_rate = conv_batchnorm_drop_rate[i]
if abs(drop_rate) > 1e-5: if abs(drop_rate) > 1e-5:
tmp = layers.dropout(x=tmp, dropout_prob=drop_rate) tmp = layers.dropout(x=tmp, dropout_prob=drop_rate)
......
...@@ -256,6 +256,9 @@ def main(net_type, use_cuda, is_local=True): ...@@ -256,6 +256,9 @@ def main(net_type, use_cuda, is_local=True):
save_dirname = "image_classification_" + net_type + ".inference.model" save_dirname = "image_classification_" + net_type + ".inference.model"
train(net_type, use_cuda, save_dirname, is_local) train(net_type, use_cuda, save_dirname, is_local)
# There is bug in fluid.InferenceTranspiler for VGG.
if net_type == "resnet":
infer(use_cuda, save_dirname) infer(use_cuda, save_dirname)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册