未验证 提交 21565e8d 编写于 作者: Z zhang wenhui 提交者: GitHub

fix batchnorm ,test=develop (#26972)

上级 4204ceae
......@@ -176,13 +176,12 @@ def batch_norm(x,
mean_out = running_mean
variance_out = running_var
true_data_format = ['NC', 'NCL', 'NCHW', 'NCWH', 'NCDHW']
true_data_format = ['NC', 'NCL', 'NCHW', 'NCDHW']
if data_format not in true_data_format:
raise ValueError(
"data_format must be one of 'NC', 'NCL', 'NCHW', 'NCWH', 'NCDHW', but receive {}".
"data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', but receive {}".
format(data_format))
if data_format != 'NCWH':
data_format = 'NCHW'
if in_dygraph_mode():
......
......@@ -811,7 +811,7 @@ class BatchNorm2d(_BatchNormBase):
If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable.
If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW.
data_format(str, optional): Specify the input data format, the data format can be "NCHW". Default: NCHW.
track_running_stats(bool, optional): Whether to use global mean and variance. In train period,
True will track global mean and variance used for inference. When inference, track_running_stats must be
True. Default: True.
......@@ -844,10 +844,10 @@ class BatchNorm2d(_BatchNormBase):
"""
def _check_data_format(self, input):
if input == 'NCHW' or input == 'NCWH':
if input == 'NCHW':
self._data_format = input
else:
raise ValueError('expected NCHW or NCWH for data_format input')
raise ValueError('expected NCHW for data_format input')
def _check_input_dim(self, input):
if len(input.shape) != 4:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册