diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index f63fc33525576c61998be9facf32b4c66aa2a971..9e8f365f6d23a95275b9a696f6088bb287108ec0 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -176,14 +176,13 @@ def batch_norm(x, mean_out = running_mean variance_out = running_var - true_data_format = ['NC', 'NCL', 'NCHW', 'NCWH', 'NCDHW'] + true_data_format = ['NC', 'NCL', 'NCHW', 'NCDHW'] if data_format not in true_data_format: raise ValueError( - "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCWH', 'NCDHW', but receive {}". + "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', but receive {}". format(data_format)) - if data_format != 'NCWH': - data_format = 'NCHW' + data_format = 'NCHW' if in_dygraph_mode(): # for dygraph need tuple diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 8bdb09c76918d747d644a8c781dabb6aab41522c..d13bf66ba5bfe483284e78dbcd2a42f8f3397210 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -811,7 +811,7 @@ class BatchNorm2d(_BatchNormBase): If it is set to None or one attribute of ParamAttr, batch_norm will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. - data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW. + data_format(str, optional): Specify the input data format, the data format can be "NCHW". Default: NCHW. track_running_stats(bool, optional): Whether to use global mean and variance. In train period, True will track global mean and variance used for inference. When inference, track_running_stats must be True. Default: True. @@ -844,10 +844,10 @@ class BatchNorm2d(_BatchNormBase): """ def _check_data_format(self, input): - if input == 'NCHW' or input == 'NCWH': + if input == 'NCHW': self._data_format = input else: - raise ValueError('expected NCHW or NCWH for data_format input') + raise ValueError('expected NCHW for data_format input') def _check_input_dim(self, input): if len(input.shape) != 4: